From 5962c6efe92b9baa10cac35396131f2bbdcf8ade Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 19 Jun 2020 15:13:58 +0800 Subject: [PATCH] solve broadcast two same shape bprop error make unsupported shape error info explicit --- mindspore/ops/_grad/grad_array_ops.py | 4 ++++ mindspore/ops/operations/array_ops.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 1be108d3a7..923a8783b3 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -673,6 +673,10 @@ def get_bprop_broadcast_to(self): def bprop(x, out, dout): x_shape = shape_op(x) + dout_shape = shape_op(dout) + + if x_shape == dout_shape: + return (dout,) _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape) reduced_grad = reduce_keep_dim(dout, reduction_axes) dx = reshape(reduced_grad, x_shape) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 1dbbe3c42b..8c87a6c5e9 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2719,6 +2719,8 @@ class BatchToSpaceND(PrimitiveWithInfer): class BroadcastTo(PrimitiveWithInfer): """ Broadcasts input tensor to a given shape. + Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one. + When input shape is broadcast to target shape, it starts with the trailing dimensions. Args: shape (tuple): The target shape to broadcast. @@ -2741,11 +2743,20 @@ class BroadcastTo(PrimitiveWithInfer): def __init__(self, shape): """Init BroadcastTo""" validator.check_value_type("shape", shape, (tuple), self.name) + validator.check("shape length", len(shape), "", 0, Rel.GT, self.name) for i in shape: validator.check_integer("shape element", i, 0, Rel.GT, self.name) self.shape = shape def infer_shape(self, x_shape): + validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name) + + reversed_x_shape = tuple(reversed(x_shape)) + reversed_target = tuple(reversed(self.shape)) + for i, v in enumerate(reversed_x_shape): + if v not in (reversed_target[i], 1): + raise ValueError(f"Not supported shapes for broadcast, " + f"x_shape: {tuple(x_shape)}, target shape {self.shape}.") return self.shape def infer_dtype(self, x_dtype):