Browse Source

!2356 fix issue broadcast_to same shape bprop error

Merge pull request !2356 from zhaozhenlong/fix-issues-broadcastto-same-shape
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
d840743fba
2 changed files with 15 additions and 0 deletions
  1. +4
    -0
      mindspore/ops/_grad/grad_array_ops.py
  2. +11
    -0
      mindspore/ops/operations/array_ops.py

+ 4
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -673,6 +673,10 @@ def get_bprop_broadcast_to(self):

def bprop(x, out, dout):
x_shape = shape_op(x)
dout_shape = shape_op(dout)

if x_shape == dout_shape:
return (dout,)
_, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
reduced_grad = reduce_keep_dim(dout, reduction_axes)
dx = reshape(reduced_grad, x_shape)


+ 11
- 0
mindspore/ops/operations/array_ops.py View File

@@ -2716,6 +2716,8 @@ class BatchToSpaceND(PrimitiveWithInfer):
class BroadcastTo(PrimitiveWithInfer):
"""
Broadcasts input tensor to a given shape.
Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
When input shape is broadcast to target shape, it starts with the trailing dimensions.

Args:
shape (tuple): The target shape to broadcast.
@@ -2738,11 +2740,20 @@ class BroadcastTo(PrimitiveWithInfer):
def __init__(self, shape):
"""Init BroadcastTo"""
validator.check_value_type("shape", shape, (tuple), self.name)
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
for i in shape:
validator.check_integer("shape element", i, 0, Rel.GT, self.name)
self.shape = shape

def infer_shape(self, x_shape):
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)

reversed_x_shape = tuple(reversed(x_shape))
reversed_target = tuple(reversed(self.shape))
for i, v in enumerate(reversed_x_shape):
if v not in (reversed_target[i], 1):
raise ValueError(f"Not supported shapes for broadcast, "
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
return self.shape

def infer_dtype(self, x_dtype):


Loading…
Cancel
Save