| @@ -19,6 +19,7 @@ import argparse | |||||
| import time | import time | ||||
| import numpy as np | import numpy as np | ||||
| from pycocotools.coco import COCO | from pycocotools.coco import COCO | ||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.common import set_seed, Parameter | from mindspore.common import set_seed, Parameter | ||||
| @@ -51,7 +52,11 @@ def fasterrcnn_eval(dataset_path, ckpt_path, ann_file): | |||||
| tensor = value.asnumpy().astype(np.float32) | tensor = value.asnumpy().astype(np.float32) | ||||
| param_dict[key] = Parameter(tensor, key) | param_dict[key] = Parameter(tensor, key) | ||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| net.set_train(False) | net.set_train(False) | ||||
| device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others" | |||||
| if device_type == "Ascend": | |||||
| net.to_float(mstype.float16) | |||||
| eval_iter = 0 | eval_iter = 0 | ||||
| total = ds.get_dataset_size() | total = ds.get_dataset_size() | ||||
| @@ -16,6 +16,7 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| @@ -144,6 +145,7 @@ class Faster_Rcnn_Resnet50(nn.Cell): | |||||
| # Init tensor | # Init tensor | ||||
| self.init_tensor(config) | self.init_tensor(config) | ||||
| self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others" | |||||
| def roi_init(self, config): | def roi_init(self, config): | ||||
| self.roi_align = SingleRoIExtractor(config, | self.roi_align = SingleRoIExtractor(config, | ||||
| @@ -267,6 +269,8 @@ class Faster_Rcnn_Resnet50(nn.Cell): | |||||
| bboxes_all = self.concat(bboxes_tuple) | bboxes_all = self.concat(bboxes_tuple) | ||||
| else: | else: | ||||
| bboxes_all = bboxes_tuple[0] | bboxes_all = bboxes_tuple[0] | ||||
| if self.device_type == "Ascend": | |||||
| bboxes_all = self.cast(bboxes_all, mstype.float16) | |||||
| rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all)) | rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all)) | ||||
| rois = self.cast(rois, mstype.float32) | rois = self.cast(rois, mstype.float32) | ||||
| @@ -40,7 +40,7 @@ class DenseNoTranpose(nn.Cell): | |||||
| if self.device_type == "Ascend": | if self.device_type == "Ascend": | ||||
| x = self.cast(x, mstype.float16) | x = self.cast(x, mstype.float16) | ||||
| weight = self.cast(self.weight, mstype.float16) | weight = self.cast(self.weight, mstype.float16) | ||||
| output = self.bias_add(self.cast(self.matmul(x, weight), mstype.float32), self.bias) | |||||
| output = self.bias_add(self.matmul(x, weight), self.bias) | |||||
| else: | else: | ||||
| output = self.bias_add(self.matmul(x, self.weight), self.bias) | output = self.bias_add(self.matmul(x, self.weight), self.bias) | ||||
| return output | return output | ||||
| @@ -16,7 +16,7 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore import Tensor | |||||
| from mindspore import context, Tensor | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| @@ -102,6 +102,7 @@ class RPN(nn.Cell): | |||||
| cfg_rpn = config | cfg_rpn = config | ||||
| self.dtype = np.float32 | self.dtype = np.float32 | ||||
| self.ms_type = mstype.float32 | self.ms_type = mstype.float32 | ||||
| self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others" | |||||
| self.num_bboxes = cfg_rpn.num_bboxes | self.num_bboxes = cfg_rpn.num_bboxes | ||||
| self.slice_index = () | self.slice_index = () | ||||
| self.feature_anchor_shape = () | self.feature_anchor_shape = () | ||||
| @@ -180,9 +181,12 @@ class RPN(nn.Cell): | |||||
| bias_reg = initializer(0, shape=shp_bias_reg, dtype=self.ms_type).to_tensor() | bias_reg = initializer(0, shape=shp_bias_reg, dtype=self.ms_type).to_tensor() | ||||
| for i in range(num_layers): | for i in range(num_layers): | ||||
| rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ | |||||
| rpn_reg_cls_block = RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ | |||||
| weight_conv, bias_conv, weight_cls, \ | weight_conv, bias_conv, weight_cls, \ | ||||
| bias_cls, weight_reg, bias_reg)) | |||||
| bias_cls, weight_reg, bias_reg) | |||||
| if self.device_type == "Ascend": | |||||
| rpn_reg_cls_block.to_float(mstype.float16) | |||||
| rpn_layer.append(rpn_reg_cls_block) | |||||
| for i in range(1, num_layers): | for i in range(1, num_layers): | ||||
| rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight | rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight | ||||
| @@ -250,6 +254,7 @@ class RPN(nn.Cell): | |||||
| mstype.bool_), | mstype.bool_), | ||||
| anchor_using_list, gt_valids_i) | anchor_using_list, gt_valids_i) | ||||
| bbox_target = self.cast(bbox_target, self.ms_type) | |||||
| bbox_weight = self.cast(bbox_weight, self.ms_type) | bbox_weight = self.cast(bbox_weight, self.ms_type) | ||||
| label = self.cast(label, self.ms_type) | label = self.cast(label, self.ms_type) | ||||
| label_weight = self.cast(label_weight, self.ms_type) | label_weight = self.cast(label_weight, self.ms_type) | ||||
| @@ -286,8 +291,8 @@ class RPN(nn.Cell): | |||||
| label_ = F.stop_gradient(label_with_batchsize) | label_ = F.stop_gradient(label_with_batchsize) | ||||
| label_weight_ = F.stop_gradient(label_weight_with_batchsize) | label_weight_ = F.stop_gradient(label_weight_with_batchsize) | ||||
| cls_score_i = rpn_cls_score[i] | |||||
| reg_score_i = rpn_bbox_pred[i] | |||||
| cls_score_i = self.cast(rpn_cls_score[i], self.ms_type) | |||||
| reg_score_i = self.cast(rpn_bbox_pred[i], self.ms_type) | |||||
| loss_cls = self.loss_cls(cls_score_i, label_) | loss_cls = self.loss_cls(cls_score_i, label_) | ||||
| loss_cls_item = loss_cls * label_weight_ | loss_cls_item = loss_cls * label_weight_ | ||||
| @@ -152,6 +152,10 @@ if __name__ == '__main__': | |||||
| param_dict[key] = Parameter(tensor, key) | param_dict[key] = Parameter(tensor, key) | ||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others" | |||||
| if device_type == "Ascend": | |||||
| net.to_float(mstype.float16) | |||||
| loss = LossNet() | loss = LossNet() | ||||
| lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32) | lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32) | ||||