diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 8a6b2f8e45..df91c1500d 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -311,14 +311,16 @@ class Cast(PrimitiveWithInfer): def check_elim(self, x, dtype): if isinstance(x, (Tensor, numbers.Number, Parameter)): - if isinstance(x, Tensor) and x.dtype == dtype: - return (True, x) - if isinstance(x, numbers.Number): - return (True, Tensor(x, dtype=dtype)) if isinstance(x, Parameter): data = x.data if data.dtype == dtype: return (True, x) + if isinstance(x, Tensor) and x.dtype == dtype: + x = Tensor(x) + x.set_cast_dtype() + return (True, x) + if isinstance(x, numbers.Number): + return (True, Tensor(x, dtype=dtype)) return (False, None) def __infer__(self, x, t): diff --git a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py index 4eade1f188..6d08b5173a 100644 --- a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py +++ b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py @@ -143,7 +143,7 @@ class Rcnn(nn.Cell): if self.training: bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels - labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), self.ms_type) + labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value) bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)