From: @dessyang Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -21,7 +21,7 @@ import numpy as np | |||||
| from pycocotools.coco import COCO | from pycocotools.coco import COCO | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.common import set_seed | |||||
| from mindspore.common import set_seed, Parameter | |||||
| from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 | from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 | ||||
| from src.config import config | from src.config import config | ||||
| @@ -34,16 +34,22 @@ parser = argparse.ArgumentParser(description="FasterRcnn evaluation") | |||||
| parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") | parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") | ||||
| parser.add_argument("--ann_file", type=str, default="val.json", help="Ann file, default is val.json.") | parser.add_argument("--ann_file", type=str, default="val.json", help="Ann file, default is val.json.") | ||||
| parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") | parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") | ||||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||||
| help="device where the code will be implemented, default is Ascend") | |||||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) | |||||
| def FasterRcnn_eval(dataset_path, ckpt_path, ann_file): | def FasterRcnn_eval(dataset_path, ckpt_path, ann_file): | ||||
| """FasterRcnn evaluation.""" | """FasterRcnn evaluation.""" | ||||
| ds = create_fasterrcnn_dataset(dataset_path, batch_size=config.test_batch_size, is_training=False) | ds = create_fasterrcnn_dataset(dataset_path, batch_size=config.test_batch_size, is_training=False) | ||||
| net = Faster_Rcnn_Resnet50(config) | net = Faster_Rcnn_Resnet50(config) | ||||
| param_dict = load_checkpoint(ckpt_path) | param_dict = load_checkpoint(ckpt_path) | ||||
| if args_opt.device_target == "GPU": | |||||
| for key, value in param_dict.items(): | |||||
| tensor = value.asnumpy().astype(np.float32) | |||||
| param_dict[key] = Parameter(tensor, key) | |||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| net.set_train(False) | net.set_train(False) | ||||
| @@ -0,0 +1,44 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "sh run_distribute_train_gpu.sh DEVICE_NUM PRETRAINED_PATH" | |||||
| echo "for example: sh run_distribute_train_gpu.sh 8 /path/pretrain.ckpt" | |||||
| echo "It is better to use absolute path." | |||||
| echo "==============================================================================================================" | |||||
| if [ $# != 2 ] | |||||
| then | |||||
| echo "Usage: sh run_distribute_train_gpu.sh [DEVICE_NUM] [PRETRAINED_PATH]" | |||||
| exit 1 | |||||
| fi | |||||
| rm -rf run_distribute_train | |||||
| mkdir run_distribute_train | |||||
| cp -rf ../src/ ../train.py ./run_distribute_train | |||||
| cd run_distribute_train || exit | |||||
| export RANK_SIZE=$1 | |||||
| PRETRAINED_PATH=$2 | |||||
| echo "start training on $RANK_SIZE devices" | |||||
| mpirun -n $RANK_SIZE \ | |||||
| python train.py \ | |||||
| --run_distribute=True \ | |||||
| --device_target="GPU" \ | |||||
| --device_num=$RANK_SIZE \ | |||||
| --pre_trained=$PRETRAINED_PATH > log 2>&1 & | |||||
| @@ -0,0 +1,64 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| if [ $# != 2 ] | |||||
| then | |||||
| echo "Usage: sh run_eval_gpu.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]" | |||||
| exit 1 | |||||
| fi | |||||
| get_real_path(){ | |||||
| if [ "${1:0:1}" == "/" ]; then | |||||
| echo "$1" | |||||
| else | |||||
| echo "$(realpath -m $PWD/$1)" | |||||
| fi | |||||
| } | |||||
| PATH1=$(get_real_path $1) | |||||
| PATH2=$(get_real_path $2) | |||||
| echo $PATH1 | |||||
| echo $PATH2 | |||||
| if [ ! -f $PATH1 ] | |||||
| then | |||||
| echo "error: ANN_FILE=$PATH1 is not a file" | |||||
| exit 1 | |||||
| fi | |||||
| if [ ! -f $PATH2 ] | |||||
| then | |||||
| echo "error: CHECKPOINT_PATH=$PATH2 is not a file" | |||||
| exit 1 | |||||
| fi | |||||
| export DEVICE_NUM=1 | |||||
| export RANK_SIZE=$DEVICE_NUM | |||||
| export DEVICE_ID=0 | |||||
| export RANK_ID=0 | |||||
| if [ -d "eval" ]; | |||||
| then | |||||
| rm -rf ./eval | |||||
| fi | |||||
| mkdir ./eval | |||||
| cp ../*.py ./eval | |||||
| cp *.sh ./eval | |||||
| cp -r ../src ./eval | |||||
| cd ./eval || exit | |||||
| env > env.log | |||||
| echo "start eval for device $DEVICE_ID" | |||||
| python eval.py --device_target="GPU" --device_id=$DEVICE_ID --ann_file=$PATH1 --checkpoint_path=$PATH2 &> log & | |||||
| cd .. | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -19,11 +19,12 @@ import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore import context | |||||
| class BboxAssignSample(nn.Cell): | class BboxAssignSample(nn.Cell): | ||||
| """ | """ | ||||
| Bbox assigner and sampler defination. | |||||
| Bbox assigner and sampler definition. | |||||
| Args: | Args: | ||||
| config (dict): Config. | config (dict): Config. | ||||
| @@ -45,12 +46,15 @@ class BboxAssignSample(nn.Cell): | |||||
| def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): | def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): | ||||
| super(BboxAssignSample, self).__init__() | super(BboxAssignSample, self).__init__() | ||||
| cfg = config | cfg = config | ||||
| _mode_16 = bool(context.get_context("device_target") == "Ascend") | |||||
| self.dtype = np.float16 if _mode_16 else np.float32 | |||||
| self.ms_type = mstype.float16 if _mode_16 else mstype.float32 | |||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16) | |||||
| self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16) | |||||
| self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16) | |||||
| self.zero_thr = Tensor(0.0, mstype.float16) | |||||
| self.neg_iou_thr = Tensor(cfg.neg_iou_thr, self.ms_type) | |||||
| self.pos_iou_thr = Tensor(cfg.pos_iou_thr, self.ms_type) | |||||
| self.min_pos_iou = Tensor(cfg.min_pos_iou, self.ms_type) | |||||
| self.zero_thr = Tensor(0.0, self.ms_type) | |||||
| self.num_bboxes = num_bboxes | self.num_bboxes = num_bboxes | ||||
| self.num_gts = cfg.num_gts | self.num_gts = cfg.num_gts | ||||
| @@ -92,9 +96,9 @@ class BboxAssignSample(nn.Cell): | |||||
| self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) | self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) | ||||
| self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) | self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) | ||||
| self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) | |||||
| self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) | |||||
| self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) | |||||
| self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype)) | |||||
| self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=self.dtype)) | |||||
| self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=self.dtype)) | |||||
| def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): | def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): | ||||
| @@ -129,7 +133,7 @@ class BboxAssignSample(nn.Cell): | |||||
| pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) | pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) | ||||
| pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16) | |||||
| pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), self.ms_type) | |||||
| pos_check_valid = self.sum_inds(pos_check_valid, -1) | pos_check_valid = self.sum_inds(pos_check_valid, -1) | ||||
| valid_pos_index = self.less(self.range_pos_size, pos_check_valid) | valid_pos_index = self.less(self.range_pos_size, pos_check_valid) | ||||
| pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) | pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) | ||||
| @@ -140,7 +144,7 @@ class BboxAssignSample(nn.Cell): | |||||
| neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0)) | neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0)) | ||||
| num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16) | |||||
| num_pos = self.cast(self.logicalnot(valid_pos_index), self.ms_type) | |||||
| num_pos = self.sum_inds(num_pos, -1) | num_pos = self.sum_inds(num_pos, -1) | ||||
| unvalid_pos_index = self.less(self.range_pos_size, num_pos) | unvalid_pos_index = self.less(self.range_pos_size, num_pos) | ||||
| valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index) | valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index) | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -19,11 +19,12 @@ import mindspore.nn as nn | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore import context | |||||
| class BboxAssignSampleForRcnn(nn.Cell): | class BboxAssignSampleForRcnn(nn.Cell): | ||||
| """ | """ | ||||
| Bbox assigner and sampler defination. | |||||
| Bbox assigner and sampler definition. | |||||
| Args: | Args: | ||||
| config (dict): Config. | config (dict): Config. | ||||
| @@ -45,6 +46,9 @@ class BboxAssignSampleForRcnn(nn.Cell): | |||||
| def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): | def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): | ||||
| super(BboxAssignSampleForRcnn, self).__init__() | super(BboxAssignSampleForRcnn, self).__init__() | ||||
| cfg = config | cfg = config | ||||
| _mode_16 = bool(context.get_context("device_target") == "Ascend") | |||||
| self.dtype = np.float16 if _mode_16 else np.float32 | |||||
| self.ms_type = mstype.float16 if _mode_16 else mstype.float32 | |||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.neg_iou_thr = cfg.neg_iou_thr_stage2 | self.neg_iou_thr = cfg.neg_iou_thr_stage2 | ||||
| self.pos_iou_thr = cfg.pos_iou_thr_stage2 | self.pos_iou_thr = cfg.pos_iou_thr_stage2 | ||||
| @@ -83,8 +87,8 @@ class BboxAssignSampleForRcnn(nn.Cell): | |||||
| self.tile = P.Tile() | self.tile = P.Tile() | ||||
| # Check | # Check | ||||
| self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) | |||||
| self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) | |||||
| self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=self.dtype)) | |||||
| self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=self.dtype)) | |||||
| # Init tensor | # Init tensor | ||||
| self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) | self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) | ||||
| @@ -94,18 +98,18 @@ class BboxAssignSampleForRcnn(nn.Cell): | |||||
| self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) | self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) | ||||
| self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32)) | self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32)) | ||||
| self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) | |||||
| self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype)) | |||||
| self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) | self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) | ||||
| self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16)) | |||||
| self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=self.dtype)) | |||||
| self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8)) | self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8)) | ||||
| self.reshape_shape_pos = (self.num_expected_pos, 1) | self.reshape_shape_pos = (self.num_expected_pos, 1) | ||||
| self.reshape_shape_neg = (self.num_expected_neg, 1) | self.reshape_shape_neg = (self.num_expected_neg, 1) | ||||
| self.scalar_zero = Tensor(0.0, dtype=mstype.float16) | |||||
| self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16) | |||||
| self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16) | |||||
| self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16) | |||||
| self.scalar_zero = Tensor(0.0, dtype=self.ms_type) | |||||
| self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=self.ms_type) | |||||
| self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=self.ms_type) | |||||
| self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=self.ms_type) | |||||
| def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): | def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): | ||||
| gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \ | gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \ | ||||
| @@ -149,12 +153,12 @@ class BboxAssignSampleForRcnn(nn.Cell): | |||||
| # Get pos index | # Get pos index | ||||
| pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) | pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) | ||||
| pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16) | |||||
| pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), self.ms_type) | |||||
| pos_check_valid = self.sum_inds(pos_check_valid, -1) | pos_check_valid = self.sum_inds(pos_check_valid, -1) | ||||
| valid_pos_index = self.less(self.range_pos_size, pos_check_valid) | valid_pos_index = self.less(self.range_pos_size, pos_check_valid) | ||||
| pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) | pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) | ||||
| num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1) | |||||
| num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), self.ms_type), -1) | |||||
| valid_pos_index = self.cast(valid_pos_index, mstype.int32) | valid_pos_index = self.cast(valid_pos_index, mstype.int32) | ||||
| pos_index = self.reshape(pos_index, self.reshape_shape_pos) | pos_index = self.reshape(pos_index, self.reshape_shape_pos) | ||||
| valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos) | valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos) | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -20,6 +20,7 @@ from mindspore.ops import operations as P | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore import context | |||||
| from .resnet50 import ResNetFea, ResidualBlockUsing | from .resnet50 import ResNetFea, ResidualBlockUsing | ||||
| from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn | from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn | ||||
| from .fpn_neck import FeatPyramidNeck | from .fpn_neck import FeatPyramidNeck | ||||
| @@ -50,6 +51,9 @@ class Faster_Rcnn_Resnet50(nn.Cell): | |||||
| """ | """ | ||||
| def __init__(self, config): | def __init__(self, config): | ||||
| super(Faster_Rcnn_Resnet50, self).__init__() | super(Faster_Rcnn_Resnet50, self).__init__() | ||||
| _mode_16 = bool(context.get_context("device_target") == "Ascend") | |||||
| self.dtype = np.float16 if _mode_16 else np.float32 | |||||
| self.ms_type = mstype.float16 if _mode_16 else mstype.float32 | |||||
| self.train_batch_size = config.batch_size | self.train_batch_size = config.batch_size | ||||
| self.num_classes = config.num_classes | self.num_classes = config.num_classes | ||||
| self.anchor_scales = config.anchor_scales | self.anchor_scales = config.anchor_scales | ||||
| @@ -157,7 +161,7 @@ class Faster_Rcnn_Resnet50(nn.Cell): | |||||
| self.rpn_max_num = config.rpn_max_num | self.rpn_max_num = config.rpn_max_num | ||||
| self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16)) | |||||
| self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(self.dtype)) | |||||
| self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool) | self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool) | ||||
| self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool) | self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool) | ||||
| self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask, | self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask, | ||||
| @@ -165,10 +169,10 @@ class Faster_Rcnn_Resnet50(nn.Cell): | |||||
| self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask, | self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask, | ||||
| self.ones_mask, self.ones_mask, self.zeros_mask), axis=1)) | self.ones_mask, self.ones_mask, self.zeros_mask), axis=1)) | ||||
| self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr) | |||||
| self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0) | |||||
| self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1) | |||||
| self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr) | |||||
| self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_score_thr) | |||||
| self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * 0) | |||||
| self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(self.dtype) * -1) | |||||
| self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_iou_thr) | |||||
| self.test_max_per_img = config.test_max_per_img | self.test_max_per_img = config.test_max_per_img | ||||
| self.nms_test = P.NMSWithMask(config.test_iou_thr) | self.nms_test = P.NMSWithMask(config.test_iou_thr) | ||||
| self.softmax = P.Softmax(axis=1) | self.softmax = P.Softmax(axis=1) | ||||
| @@ -183,9 +187,9 @@ class Faster_Rcnn_Resnet50(nn.Cell): | |||||
| # Init tensor | # Init tensor | ||||
| roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i, | roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i, | ||||
| dtype=np.float16) for i in range(self.train_batch_size)] | |||||
| dtype=self.dtype) for i in range(self.train_batch_size)] | |||||
| roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16) \ | |||||
| roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=self.dtype) \ | |||||
| for i in range(self.test_batch_size)] | for i in range(self.test_batch_size)] | ||||
| self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index)) | self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index)) | ||||
| @@ -276,7 +280,7 @@ class Faster_Rcnn_Resnet50(nn.Cell): | |||||
| self.cast(x[3], mstype.float32)) | self.cast(x[3], mstype.float32)) | ||||
| roi_feats = self.cast(roi_feats, mstype.float16) | |||||
| roi_feats = self.cast(roi_feats, self.ms_type) | |||||
| rcnn_masks = self.concat(mask_tuple) | rcnn_masks = self.concat(mask_tuple) | ||||
| rcnn_masks = F.stop_gradient(rcnn_masks) | rcnn_masks = F.stop_gradient(rcnn_masks) | ||||
| rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_)) | rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_)) | ||||
| @@ -420,7 +424,7 @@ class Faster_Rcnn_Resnet50(nn.Cell): | |||||
| for i in range(num_levels): | for i in range(num_levels): | ||||
| anchors = self.anchor_generators[i].grid_anchors( | anchors = self.anchor_generators[i].grid_anchors( | ||||
| featmap_sizes[i], self.anchor_strides[i]) | featmap_sizes[i], self.anchor_strides[i]) | ||||
| multi_level_anchors += (Tensor(anchors.astype(np.float16)),) | |||||
| multi_level_anchors += (Tensor(anchors.astype(self.dtype)),) | |||||
| return multi_level_anchors | return multi_level_anchors | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -22,16 +22,20 @@ from mindspore.common.tensor import Tensor | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| def bias_init_zeros(shape): | def bias_init_zeros(shape): | ||||
| """Bias init method.""" | """Bias init method.""" | ||||
| return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16)) | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16)) | |||||
| return Tensor(np.array(np.zeros(shape).astype(np.float32))) | |||||
| def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): | def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): | ||||
| """Conv2D wrapper.""" | """Conv2D wrapper.""" | ||||
| shape = (out_channels, in_channels, kernel_size, kernel_size) | shape = (out_channels, in_channels, kernel_size, kernel_size) | ||||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16) | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() | |||||
| else: | |||||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32).to_tensor() | |||||
| shape_bias = (out_channels,) | shape_bias = (out_channels,) | ||||
| biass = bias_init_zeros(shape_bias) | biass = bias_init_zeros(shape_bias) | ||||
| return nn.Conv2d(in_channels, out_channels, | return nn.Conv2d(in_channels, out_channels, | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -22,9 +22,6 @@ from mindspore import Tensor | |||||
| from mindspore import context | from mindspore import context | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Proposal(nn.Cell): | class Proposal(nn.Cell): | ||||
| """ | """ | ||||
| Proposal subnet. | Proposal subnet. | ||||
| @@ -106,7 +103,11 @@ class Proposal(nn.Cell): | |||||
| self.tile = P.Tile() | self.tile = P.Tile() | ||||
| self.set_train_local(config, training=True) | self.set_train_local(config, training=True) | ||||
| self.multi_10 = Tensor(10.0, mstype.float16) | |||||
| _mode_16 = bool(context.get_context("device_target") == "Ascend") | |||||
| self.dtype = np.float16 if _mode_16 else np.float32 | |||||
| self.ms_type = mstype.float16 if _mode_16 else mstype.float32 | |||||
| self.multi_10 = Tensor(10.0, self.ms_type) | |||||
| def set_train_local(self, config, training=True): | def set_train_local(self, config, training=True): | ||||
| """Set training flag.""" | """Set training flag.""" | ||||
| @@ -133,7 +134,10 @@ class Proposal(nn.Cell): | |||||
| self.topKv2 = P.TopK(sorted=True) | self.topKv2 = P.TopK(sorted=True) | ||||
| self.topK_shape_stage2 = (self.max_num, 1) | self.topK_shape_stage2 = (self.max_num, 1) | ||||
| self.min_float_num = -65536.0 | self.min_float_num = -65536.0 | ||||
| self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16)) | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16)) | |||||
| else: | |||||
| self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float32)) | |||||
| def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list): | def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list): | ||||
| proposals_tuple = () | proposals_tuple = () | ||||
| @@ -164,16 +168,16 @@ class Proposal(nn.Cell): | |||||
| rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape) | rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape) | ||||
| rpn_cls_score = self.activation(rpn_cls_score) | rpn_cls_score = self.activation(rpn_cls_score) | ||||
| rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16) | |||||
| rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), self.ms_type) | |||||
| rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16) | |||||
| rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), self.ms_type) | |||||
| scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx]) | scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx]) | ||||
| topk_inds = self.reshape(topk_inds, self.topK_shape[idx]) | topk_inds = self.reshape(topk_inds, self.topK_shape[idx]) | ||||
| bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds) | bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds) | ||||
| anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16) | |||||
| anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), self.ms_type) | |||||
| proposals_decode = self.decode(anchors_sorted, bboxes_sorted) | proposals_decode = self.decode(anchors_sorted, bboxes_sorted) | ||||
| @@ -188,7 +192,7 @@ class Proposal(nn.Cell): | |||||
| _, _, _, _, scores = self.split(proposals) | _, _, _, _, scores = self.split(proposals) | ||||
| scores = self.squeeze(scores) | scores = self.squeeze(scores) | ||||
| topk_mask = self.cast(self.topK_mask, mstype.float16) | |||||
| topk_mask = self.cast(self.topK_mask, self.ms_type) | |||||
| scores_using = self.select(masks, scores, topk_mask) | scores_using = self.select(masks, scores, topk_mask) | ||||
| _, topk_inds = self.topKv2(scores_using, self.max_num) | _, topk_inds = self.topKv2(scores_using, self.max_num) | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -21,15 +21,19 @@ from mindspore.ops import operations as P | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore import context | |||||
| class DenseNoTranpose(nn.Cell): | class DenseNoTranpose(nn.Cell): | ||||
| """Dense method""" | """Dense method""" | ||||
| def __init__(self, input_channels, output_channels, weight_init): | def __init__(self, input_channels, output_channels, weight_init): | ||||
| super(DenseNoTranpose, self).__init__() | super(DenseNoTranpose, self).__init__() | ||||
| self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16)) | |||||
| self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16)) | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16)) | |||||
| self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16)) | |||||
| else: | |||||
| self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float32)) | |||||
| self.bias = Parameter(initializer("zeros", [output_channels], mstype.float32)) | |||||
| self.matmul = P.MatMul(transpose_b=False) | self.matmul = P.MatMul(transpose_b=False) | ||||
| self.bias_add = P.BiasAdd() | self.bias_add = P.BiasAdd() | ||||
| @@ -68,8 +72,11 @@ class Rcnn(nn.Cell): | |||||
| ): | ): | ||||
| super(Rcnn, self).__init__() | super(Rcnn, self).__init__() | ||||
| cfg = config | cfg = config | ||||
| self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16)) | |||||
| self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16)) | |||||
| _mode_16 = bool(context.get_context("device_target") == "Ascend") | |||||
| self.dtype = np.float16 if _mode_16 else np.float32 | |||||
| self.ms_type = mstype.float16 if _mode_16 else mstype.float32 | |||||
| self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(self.dtype)) | |||||
| self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(self.dtype)) | |||||
| self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels | self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels | ||||
| self.target_means = target_means | self.target_means = target_means | ||||
| self.target_stds = target_stds | self.target_stds = target_stds | ||||
| @@ -79,16 +86,16 @@ class Rcnn(nn.Cell): | |||||
| self.test_batch_size = cfg.test_batch_size | self.test_batch_size = cfg.test_batch_size | ||||
| shape_0 = (self.rcnn_fc_out_channels, representation_size) | shape_0 = (self.rcnn_fc_out_channels, representation_size) | ||||
| weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16) | |||||
| weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=self.ms_type).to_tensor() | |||||
| shape_1 = (self.rcnn_fc_out_channels, self.rcnn_fc_out_channels) | shape_1 = (self.rcnn_fc_out_channels, self.rcnn_fc_out_channels) | ||||
| weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16) | |||||
| weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=self.ms_type).to_tensor() | |||||
| self.shared_fc_0 = DenseNoTranpose(representation_size, self.rcnn_fc_out_channels, weights_0) | self.shared_fc_0 = DenseNoTranpose(representation_size, self.rcnn_fc_out_channels, weights_0) | ||||
| self.shared_fc_1 = DenseNoTranpose(self.rcnn_fc_out_channels, self.rcnn_fc_out_channels, weights_1) | self.shared_fc_1 = DenseNoTranpose(self.rcnn_fc_out_channels, self.rcnn_fc_out_channels, weights_1) | ||||
| cls_weight = initializer('Normal', shape=[num_classes, self.rcnn_fc_out_channels][::-1], | cls_weight = initializer('Normal', shape=[num_classes, self.rcnn_fc_out_channels][::-1], | ||||
| dtype=mstype.float16) | |||||
| dtype=self.ms_type).to_tensor() | |||||
| reg_weight = initializer('Normal', shape=[num_classes * 4, self.rcnn_fc_out_channels][::-1], | reg_weight = initializer('Normal', shape=[num_classes * 4, self.rcnn_fc_out_channels][::-1], | ||||
| dtype=mstype.float16) | |||||
| dtype=self.ms_type).to_tensor() | |||||
| self.cls_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes, cls_weight) | self.cls_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes, cls_weight) | ||||
| self.reg_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes * 4, reg_weight) | self.reg_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes * 4, reg_weight) | ||||
| @@ -110,13 +117,13 @@ class Rcnn(nn.Cell): | |||||
| self.on_value = Tensor(1.0, mstype.float32) | self.on_value = Tensor(1.0, mstype.float32) | ||||
| self.off_value = Tensor(0.0, mstype.float32) | self.off_value = Tensor(0.0, mstype.float32) | ||||
| self.value = Tensor(1.0, mstype.float16) | |||||
| self.value = Tensor(1.0, self.ms_type) | |||||
| self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size | self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size | ||||
| rmv_first = np.ones((self.num_bboxes, self.num_classes)) | rmv_first = np.ones((self.num_bboxes, self.num_classes)) | ||||
| rmv_first[:, 0] = np.zeros((self.num_bboxes,)) | rmv_first[:, 0] = np.zeros((self.num_bboxes,)) | ||||
| self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16)) | |||||
| self.rmv_first_tensor = Tensor(rmv_first.astype(self.dtype)) | |||||
| self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size | self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size | ||||
| @@ -134,7 +141,7 @@ class Rcnn(nn.Cell): | |||||
| if self.training: | if self.training: | ||||
| bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels | bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels | ||||
| labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16) | |||||
| labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), self.ms_type) | |||||
| bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) | bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) | ||||
| loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask) | loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask) | ||||
| @@ -149,12 +156,12 @@ class Rcnn(nn.Cell): | |||||
| loss_print = () | loss_print = () | ||||
| loss_cls, _ = self.loss_cls(cls_score, labels) | loss_cls, _ = self.loss_cls(cls_score, labels) | ||||
| weights = self.cast(weights, mstype.float16) | |||||
| weights = self.cast(weights, self.ms_type) | |||||
| loss_cls = loss_cls * weights | loss_cls = loss_cls * weights | ||||
| loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,)) | loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,)) | ||||
| bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value), | bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value), | ||||
| mstype.float16) | |||||
| self.ms_type) | |||||
| bbox_weights = bbox_weights * self.rmv_first_tensor | bbox_weights = bbox_weights * self.rmv_first_tensor | ||||
| pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4)) | pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4)) | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -22,12 +22,11 @@ from mindspore.ops import functional as F | |||||
| from mindspore import context | from mindspore import context | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| def weight_init_ones(shape): | def weight_init_ones(shape): | ||||
| """Weight init.""" | """Weight init.""" | ||||
| return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16)) | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16)) | |||||
| return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01)) | |||||
| def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): | def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): | ||||
| @@ -41,11 +40,12 @@ def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mod | |||||
| def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True): | def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True): | ||||
| """Batchnorm2D wrapper.""" | """Batchnorm2D wrapper.""" | ||||
| gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16)) | |||||
| beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16)) | |||||
| moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16)) | |||||
| moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16)) | |||||
| _mode_16 = bool(context.get_context("device_target") == "Ascend") | |||||
| dtype = np.float16 if _mode_16 else np.float32 | |||||
| gamma_init = Tensor(np.array(np.ones(out_chls)).astype(dtype)) | |||||
| beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype)) | |||||
| moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype)) | |||||
| moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(dtype)) | |||||
| return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init, | return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init, | ||||
| beta_init=beta_init, moving_mean_init=moving_mean_init, | beta_init=beta_init, moving_mean_init=moving_mean_init, | ||||
| moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics) | moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics) | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -17,7 +17,7 @@ import numpy as np | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore import Tensor | |||||
| from mindspore import Tensor, context | |||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from .bbox_assign_sample import BboxAssignSample | from .bbox_assign_sample import BboxAssignSample | ||||
| @@ -100,6 +100,9 @@ class RPN(nn.Cell): | |||||
| cls_out_channels): | cls_out_channels): | ||||
| super(RPN, self).__init__() | super(RPN, self).__init__() | ||||
| cfg_rpn = config | cfg_rpn = config | ||||
| _mode_16 = bool(context.get_context("device_target") == "Ascend") | |||||
| self.dtype = np.float16 if _mode_16 else np.float32 | |||||
| self.ms_type = mstype.float16 if _mode_16 else mstype.float32 | |||||
| self.num_bboxes = cfg_rpn.num_bboxes | self.num_bboxes = cfg_rpn.num_bboxes | ||||
| self.slice_index = () | self.slice_index = () | ||||
| self.feature_anchor_shape = () | self.feature_anchor_shape = () | ||||
| @@ -114,7 +117,7 @@ class RPN(nn.Cell): | |||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.test_batch_size = cfg_rpn.test_batch_size | self.test_batch_size = cfg_rpn.test_batch_size | ||||
| self.num_layers = 5 | self.num_layers = 5 | ||||
| self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16)) | |||||
| self.real_ratio = Tensor(np.ones((1, 1)).astype(self.dtype)) | |||||
| self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels, | self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels, | ||||
| num_anchors, cls_out_channels)) | num_anchors, cls_out_channels)) | ||||
| @@ -123,15 +126,15 @@ class RPN(nn.Cell): | |||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| self.concat = P.Concat(axis=0) | self.concat = P.Concat(axis=0) | ||||
| self.fill = P.Fill() | self.fill = P.Fill() | ||||
| self.placeh1 = Tensor(np.ones((1,)).astype(np.float16)) | |||||
| self.placeh1 = Tensor(np.ones((1,)).astype(self.dtype)) | |||||
| self.trans_shape = (0, 2, 3, 1) | self.trans_shape = (0, 2, 3, 1) | ||||
| self.reshape_shape_reg = (-1, 4) | self.reshape_shape_reg = (-1, 4) | ||||
| self.reshape_shape_cls = (-1,) | self.reshape_shape_cls = (-1,) | ||||
| self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16)) | |||||
| self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16)) | |||||
| self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16)) | |||||
| self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(self.dtype)) | |||||
| self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(self.dtype)) | |||||
| self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(self.dtype)) | |||||
| self.num_bboxes = cfg_rpn.num_bboxes | self.num_bboxes = cfg_rpn.num_bboxes | ||||
| self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False) | self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False) | ||||
| self.CheckValid = P.CheckValid() | self.CheckValid = P.CheckValid() | ||||
| @@ -142,9 +145,9 @@ class RPN(nn.Cell): | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.tile = P.Tile() | self.tile = P.Tile() | ||||
| self.zeros_like = P.ZerosLike() | self.zeros_like = P.ZerosLike() | ||||
| self.loss = Tensor(np.zeros((1,)).astype(np.float16)) | |||||
| self.clsloss = Tensor(np.zeros((1,)).astype(np.float16)) | |||||
| self.regloss = Tensor(np.zeros((1,)).astype(np.float16)) | |||||
| self.loss = Tensor(np.zeros((1,)).astype(self.dtype)) | |||||
| self.clsloss = Tensor(np.zeros((1,)).astype(self.dtype)) | |||||
| self.regloss = Tensor(np.zeros((1,)).astype(self.dtype)) | |||||
| def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels): | def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels): | ||||
| """ | """ | ||||
| @@ -164,18 +167,18 @@ class RPN(nn.Cell): | |||||
| shp_weight_conv = (feat_channels, in_channels, 3, 3) | shp_weight_conv = (feat_channels, in_channels, 3, 3) | ||||
| shp_bias_conv = (feat_channels,) | shp_bias_conv = (feat_channels,) | ||||
| weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16) | |||||
| bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16) | |||||
| weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=self.ms_type).to_tensor() | |||||
| bias_conv = initializer(0, shape=shp_bias_conv, dtype=self.ms_type).to_tensor() | |||||
| shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1) | shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1) | ||||
| shp_bias_cls = (num_anchors * cls_out_channels,) | shp_bias_cls = (num_anchors * cls_out_channels,) | ||||
| weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16) | |||||
| bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16) | |||||
| weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=self.ms_type).to_tensor() | |||||
| bias_cls = initializer(0, shape=shp_bias_cls, dtype=self.ms_type).to_tensor() | |||||
| shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1) | shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1) | ||||
| shp_bias_reg = (num_anchors * 4,) | shp_bias_reg = (num_anchors * 4,) | ||||
| weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16) | |||||
| bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16) | |||||
| weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=self.ms_type).to_tensor() | |||||
| bias_reg = initializer(0, shape=shp_bias_reg, dtype=self.ms_type).to_tensor() | |||||
| for i in range(num_layers): | for i in range(num_layers): | ||||
| rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ | rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ | ||||
| @@ -248,9 +251,9 @@ class RPN(nn.Cell): | |||||
| mstype.bool_), | mstype.bool_), | ||||
| anchor_using_list, gt_valids_i) | anchor_using_list, gt_valids_i) | ||||
| bbox_weight = self.cast(bbox_weight, mstype.float16) | |||||
| label = self.cast(label, mstype.float16) | |||||
| label_weight = self.cast(label_weight, mstype.float16) | |||||
| bbox_weight = self.cast(bbox_weight, self.ms_type) | |||||
| label = self.cast(label, self.ms_type) | |||||
| label_weight = self.cast(label_weight, self.ms_type) | |||||
| for j in range(self.num_layers): | for j in range(self.num_layers): | ||||
| begin = self.slice_index[j] | begin = self.slice_index[j] | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -113,8 +113,6 @@ config = ed({ | |||||
| # LR | # LR | ||||
| "base_lr": 0.02, | "base_lr": 0.02, | ||||
| "base_step": 58633, | |||||
| "total_epoch": 13, | |||||
| "warmup_step": 500, | "warmup_step": 500, | ||||
| "warmup_ratio": 1/3.0, | "warmup_ratio": 1/3.0, | ||||
| "sgd_step": [8, 11], | "sgd_step": [8, 11], | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -21,6 +21,7 @@ import numpy as np | |||||
| from numpy import random | from numpy import random | ||||
| import mmcv | import mmcv | ||||
| from mindspore import context | |||||
| import mindspore.dataset as de | import mindspore.dataset as de | ||||
| import mindspore.dataset.vision.c_transforms as C | import mindspore.dataset.vision.c_transforms as C | ||||
| from mindspore.mindrecord import FileWriter | from mindspore.mindrecord import FileWriter | ||||
| @@ -213,7 +214,7 @@ def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
| def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num): | def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num): | ||||
| """imnormalize operation for image""" | """imnormalize operation for image""" | ||||
| img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True) | |||||
| img_data = mmcv.imnormalize(img, np.array([123.675, 116.28, 103.53]), np.array([58.395, 57.12, 57.375]), True) | |||||
| img_data = img_data.astype(np.float32) | img_data = img_data.astype(np.float32) | ||||
| return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | ||||
| @@ -232,9 +233,14 @@ def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
| def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num): | def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num): | ||||
| """transpose operation for image""" | """transpose operation for image""" | ||||
| img_data = img.transpose(2, 0, 1).copy() | img_data = img.transpose(2, 0, 1).copy() | ||||
| img_data = img_data.astype(np.float16) | |||||
| img_shape = img_shape.astype(np.float16) | |||||
| gt_bboxes = gt_bboxes.astype(np.float16) | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| img_data = img_data.astype(np.float16) | |||||
| img_shape = img_shape.astype(np.float16) | |||||
| gt_bboxes = gt_bboxes.astype(np.float16) | |||||
| else: | |||||
| img_data = img_data.astype(np.float32) | |||||
| img_shape = img_shape.astype(np.float32) | |||||
| gt_bboxes = gt_bboxes.astype(np.float32) | |||||
| gt_label = gt_label.astype(np.int32) | gt_label = gt_label.astype(np.int32) | ||||
| gt_num = gt_num.astype(np.bool) | gt_num = gt_num.astype(np.bool) | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -25,12 +25,10 @@ def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): | |||||
| learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | ||||
| return learning_rate | return learning_rate | ||||
| def dynamic_lr(config, rank_size=1): | |||||
| def dynamic_lr(config, steps_per_epoch): | |||||
| """dynamic learning rate generator""" | """dynamic learning rate generator""" | ||||
| base_lr = config.base_lr | base_lr = config.base_lr | ||||
| base_step = (config.base_step // rank_size) + rank_size | |||||
| total_steps = int(base_step * config.total_epoch) | |||||
| total_steps = steps_per_epoch * config.epoch_size | |||||
| warmup_steps = int(config.warmup_step) | warmup_steps = int(config.warmup_step) | ||||
| lr = [] | lr = [] | ||||
| for i in range(total_steps): | for i in range(total_steps): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -20,7 +20,7 @@ import mindspore.nn as nn | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore import ParameterTuple | |||||
| from mindspore import ParameterTuple, context | |||||
| from mindspore.train.callback import Callback | from mindspore.train.callback import Callback | ||||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | ||||
| @@ -167,7 +167,10 @@ class TrainOneStepCell(nn.Cell): | |||||
| self.optimizer = optimizer | self.optimizer = optimizer | ||||
| self.grad = C.GradOperation(get_by_list=True, | self.grad = C.GradOperation(get_by_list=True, | ||||
| sens_param=True) | sens_param=True) | ||||
| self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16)) | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16)) | |||||
| else: | |||||
| self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32)) | |||||
| self.reduce_flag = reduce_flag | self.reduce_flag = reduce_flag | ||||
| if reduce_flag: | if reduce_flag: | ||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -19,10 +19,11 @@ import os | |||||
| import time | import time | ||||
| import argparse | import argparse | ||||
| import ast | import ast | ||||
| import numpy as np | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore import context, Tensor | |||||
| from mindspore.communication.management import init | |||||
| from mindspore import context, Tensor, Parameter | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor | from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor | ||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| @@ -42,20 +43,30 @@ parser = argparse.ArgumentParser(description="FasterRcnn training") | |||||
| parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.") | parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.") | ||||
| parser.add_argument("--dataset", type=str, default="coco", help="Dataset name, default: coco.") | parser.add_argument("--dataset", type=str, default="coco", help="Dataset name, default: coco.") | ||||
| parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.") | parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.") | ||||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||||
| help="device where the code will be implemented, default is Ascend") | |||||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") | parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") | ||||
| parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.") | parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.") | ||||
| parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.") | parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.") | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| if args_opt.run_distribute: | if args_opt.run_distribute: | ||||
| rank = args_opt.rank_id | |||||
| device_num = args_opt.device_num | |||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True) | |||||
| init() | |||||
| if args_opt.device_target == "Ascend": | |||||
| rank = args_opt.rank_id | |||||
| device_num = args_opt.device_num | |||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True) | |||||
| init() | |||||
| else: | |||||
| init("nccl") | |||||
| context.reset_auto_parallel_context() | |||||
| rank = get_rank() | |||||
| device_num = get_group_size() | |||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True) | |||||
| else: | else: | ||||
| rank = 0 | rank = 0 | ||||
| device_num = 1 | device_num = 1 | ||||
| @@ -116,10 +127,14 @@ if __name__ == '__main__': | |||||
| for item in list(param_dict.keys()): | for item in list(param_dict.keys()): | ||||
| if not item.startswith('backbone'): | if not item.startswith('backbone'): | ||||
| param_dict.pop(item) | param_dict.pop(item) | ||||
| if args_opt.device_target == "GPU": | |||||
| for key, value in param_dict.items(): | |||||
| tensor = value.asnumpy().astype(np.float32) | |||||
| param_dict[key] = Parameter(tensor, key) | |||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| loss = LossNet() | loss = LossNet() | ||||
| lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32) | |||||
| lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32) | |||||
| opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, | opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, | ||||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale) | weight_decay=config.weight_decay, loss_scale=config.loss_scale) | ||||
| @@ -141,4 +156,4 @@ if __name__ == '__main__': | |||||
| cb += [ckpoint_cb] | cb += [ckpoint_cb] | ||||
| model = Model(net) | model = Model(net) | ||||
| model.train(config.epoch_size, dataset, callbacks=cb) | |||||
| model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False) | |||||