From: @chujinjin Reviewed-by: @linqingke,@kisnwang Signed-off-by: @linqingkepull/15865/MERGE
| @@ -311,14 +311,16 @@ class Cast(PrimitiveWithInfer): | |||||
| def check_elim(self, x, dtype): | def check_elim(self, x, dtype): | ||||
| if isinstance(x, (Tensor, numbers.Number, Parameter)): | if isinstance(x, (Tensor, numbers.Number, Parameter)): | ||||
| if isinstance(x, Tensor) and x.dtype == dtype: | |||||
| return (True, x) | |||||
| if isinstance(x, numbers.Number): | |||||
| return (True, Tensor(x, dtype=dtype)) | |||||
| if isinstance(x, Parameter): | if isinstance(x, Parameter): | ||||
| data = x.data | data = x.data | ||||
| if data.dtype == dtype: | if data.dtype == dtype: | ||||
| return (True, x) | return (True, x) | ||||
| if isinstance(x, Tensor) and x.dtype == dtype: | |||||
| x = Tensor(x) | |||||
| x.set_cast_dtype() | |||||
| return (True, x) | |||||
| if isinstance(x, numbers.Number): | |||||
| return (True, Tensor(x, dtype=dtype)) | |||||
| return (False, None) | return (False, None) | ||||
| def __infer__(self, x, t): | def __infer__(self, x, t): | ||||
| @@ -143,7 +143,7 @@ class Rcnn(nn.Cell): | |||||
| if self.training: | if self.training: | ||||
| bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels | bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels | ||||
| labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), self.ms_type) | |||||
| labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value) | |||||
| bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) | bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) | ||||
| loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask) | loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask) | ||||