From 31ee29e7d548e42553b4d79f7095975e60c140f5 Mon Sep 17 00:00:00 2001 From: chujinjin Date: Tue, 6 Apr 2021 21:39:41 +0800 Subject: [PATCH] fix fastrcnn for pynative --- mindspore/ops/operations/array_ops.py | 2 ++ model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 201c71ca28..13bfebd85c 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -321,6 +321,8 @@ class Cast(PrimitiveWithInfer): def check_elim(self, x, dtype): if isinstance(x, (Tensor, numbers.Number, Parameter)): if isinstance(x, Tensor) and x.dtype == dtype: + x = Tensor(x) + x.set_cast_dtype() return (True, x) if isinstance(x, numbers.Number): return (True, Tensor(x, dtype=dtype)) diff --git a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py index 4eade1f188..6d08b5173a 100644 --- a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py +++ b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py @@ -143,7 +143,7 @@ class Rcnn(nn.Cell): if self.training: bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels - labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), self.ms_type) + labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value) bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)