Browse Source

!15522 fix bug of sparsegather and broadcastto

From: @simson_wu
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
pull/15522/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
f31c7c3fff
3 changed files with 28 additions and 3 deletions
  1. +10
    -1
      mindspore/core/ops/broadcast_to.cc
  2. +1
    -1
      mindspore/ops/_grad/grad_array_ops.py
  3. +17
    -1
      mindspore/ops/operations/array_ops.py

+ 10
- 1
mindspore/core/ops/broadcast_to.cc View File

@@ -48,7 +48,16 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
} }
} }
} }
return std::make_shared<abstract::Shape>(input_x);
auto x_shape_ptr = std::make_shared<abstract::Shape>(input_x);
primitive->AddAttr("shape", MakeValue(input_x));
for (int64_t i = 0; i < (int64_t)x_shape.size(); i++) {
if (input_x[i + outer_dim_offset] != x_shape[i] && x_shape[i] != 1) {
MS_EXCEPTION(ValueError) << "Not support shapes for broadcast, x_shape: "
<< input_args[0]->BuildShape()->ToString()
<< ", target shape: " << x_shape_ptr->ToString();
}
}
return x_shape_ptr;
} }


TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {


+ 1
- 1
mindspore/ops/_grad/grad_array_ops.py View File

@@ -980,11 +980,11 @@ def get_bprop_batch_to_space_nd(self):
def get_bprop_broadcast_to(self): def get_bprop_broadcast_to(self):
"""Generate bprop for BroadcastTo""" """Generate bprop for BroadcastTo"""
reduce_keep_dim = P.ReduceSum(keep_dims=True) reduce_keep_dim = P.ReduceSum(keep_dims=True)
broadcast_shape = self.shape


def bprop(x, out, dout): def bprop(x, out, dout):
x_shape = shape_op(x) x_shape = shape_op(x)
dout_shape = shape_op(dout) dout_shape = shape_op(dout)
broadcast_shape = shape_op(out)


if x_shape == dout_shape: if x_shape == dout_shape:
return (dout,) return (dout,)


+ 17
- 1
mindspore/ops/operations/array_ops.py View File

@@ -861,7 +861,7 @@ class GatherV2(PrimitiveWithCheck):
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)




class SparseGatherV2(Gather):
class SparseGatherV2(PrimitiveWithCheck):
""" """
Returns a slice of input tensor based on the specified indices and axis. Returns a slice of input tensor based on the specified indices and axis.


@@ -893,6 +893,22 @@ class SparseGatherV2(Gather):
[2. 55.]] [2. 55.]]
""" """


@prim_attr_register
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])


def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
axis_v = axis['value']
validator.check_value_type('axis', axis_v, [int], self.name)
rank = len(params['shape'])
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)




class Padding(PrimitiveWithInfer): class Padding(PrimitiveWithInfer):
""" """


Loading…
Cancel
Save