Browse Source

solve broadcast two same shape bprop error

make unsupported shape error info explicit
tags/v0.5.0-beta
zhaozhenlong 5 years ago
parent
commit
5962c6efe9
2 changed files with 15 additions and 0 deletions
  1. +4
    -0
      mindspore/ops/_grad/grad_array_ops.py
  2. +11
    -0
      mindspore/ops/operations/array_ops.py

+ 4
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -673,6 +673,10 @@ def get_bprop_broadcast_to(self):


def bprop(x, out, dout): def bprop(x, out, dout):
x_shape = shape_op(x) x_shape = shape_op(x)
dout_shape = shape_op(dout)

if x_shape == dout_shape:
return (dout,)
_, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape) _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
reduced_grad = reduce_keep_dim(dout, reduction_axes) reduced_grad = reduce_keep_dim(dout, reduction_axes)
dx = reshape(reduced_grad, x_shape) dx = reshape(reduced_grad, x_shape)


+ 11
- 0
mindspore/ops/operations/array_ops.py View File

@@ -2719,6 +2719,8 @@ class BatchToSpaceND(PrimitiveWithInfer):
class BroadcastTo(PrimitiveWithInfer): class BroadcastTo(PrimitiveWithInfer):
""" """
Broadcasts input tensor to a given shape. Broadcasts input tensor to a given shape.
Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
When input shape is broadcast to target shape, it starts with the trailing dimensions.


Args: Args:
shape (tuple): The target shape to broadcast. shape (tuple): The target shape to broadcast.
@@ -2741,11 +2743,20 @@ class BroadcastTo(PrimitiveWithInfer):
def __init__(self, shape): def __init__(self, shape):
"""Init BroadcastTo""" """Init BroadcastTo"""
validator.check_value_type("shape", shape, (tuple), self.name) validator.check_value_type("shape", shape, (tuple), self.name)
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
for i in shape: for i in shape:
validator.check_integer("shape element", i, 0, Rel.GT, self.name) validator.check_integer("shape element", i, 0, Rel.GT, self.name)
self.shape = shape self.shape = shape


def infer_shape(self, x_shape): def infer_shape(self, x_shape):
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)

reversed_x_shape = tuple(reversed(x_shape))
reversed_target = tuple(reversed(self.shape))
for i, v in enumerate(reversed_x_shape):
if v not in (reversed_target[i], 1):
raise ValueError(f"Not supported shapes for broadcast, "
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
return self.shape return self.shape


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):


Loading…
Cancel
Save