Browse Source

!15865 fix fasterrcnn fail in pynative

From: @chujinjin
Reviewed-by: @linqingke,@kisnwang
Signed-off-by: @linqingke
pull/15865/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
9fbfc63de9
2 changed files with 7 additions and 5 deletions
  1. +6
    -4
      mindspore/ops/operations/array_ops.py
  2. +1
    -1
      model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py

+ 6
- 4
mindspore/ops/operations/array_ops.py View File

@@ -311,14 +311,16 @@ class Cast(PrimitiveWithInfer):

def check_elim(self, x, dtype):
if isinstance(x, (Tensor, numbers.Number, Parameter)):
if isinstance(x, Tensor) and x.dtype == dtype:
return (True, x)
if isinstance(x, numbers.Number):
return (True, Tensor(x, dtype=dtype))
if isinstance(x, Parameter):
data = x.data
if data.dtype == dtype:
return (True, x)
if isinstance(x, Tensor) and x.dtype == dtype:
x = Tensor(x)
x.set_cast_dtype()
return (True, x)
if isinstance(x, numbers.Number):
return (True, Tensor(x, dtype=dtype))
return (False, None)

def __infer__(self, x, t):


+ 1
- 1
model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py View File

@@ -143,7 +143,7 @@ class Rcnn(nn.Cell):

if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), self.ms_type)
labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value)
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))

loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)


Loading…
Cancel
Save