From: @simson_wu Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qhpull/15522/MERGE
| @@ -48,7 +48,16 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return std::make_shared<abstract::Shape>(input_x); | |||||
| auto x_shape_ptr = std::make_shared<abstract::Shape>(input_x); | |||||
| primitive->AddAttr("shape", MakeValue(input_x)); | |||||
| for (int64_t i = 0; i < (int64_t)x_shape.size(); i++) { | |||||
| if (input_x[i + outer_dim_offset] != x_shape[i] && x_shape[i] != 1) { | |||||
| MS_EXCEPTION(ValueError) << "Not support shapes for broadcast, x_shape: " | |||||
| << input_args[0]->BuildShape()->ToString() | |||||
| << ", target shape: " << x_shape_ptr->ToString(); | |||||
| } | |||||
| } | |||||
| return x_shape_ptr; | |||||
| } | } | ||||
| TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | ||||
| @@ -980,11 +980,11 @@ def get_bprop_batch_to_space_nd(self): | |||||
| def get_bprop_broadcast_to(self): | def get_bprop_broadcast_to(self): | ||||
| """Generate bprop for BroadcastTo""" | """Generate bprop for BroadcastTo""" | ||||
| reduce_keep_dim = P.ReduceSum(keep_dims=True) | reduce_keep_dim = P.ReduceSum(keep_dims=True) | ||||
| broadcast_shape = self.shape | |||||
| def bprop(x, out, dout): | def bprop(x, out, dout): | ||||
| x_shape = shape_op(x) | x_shape = shape_op(x) | ||||
| dout_shape = shape_op(dout) | dout_shape = shape_op(dout) | ||||
| broadcast_shape = shape_op(out) | |||||
| if x_shape == dout_shape: | if x_shape == dout_shape: | ||||
| return (dout,) | return (dout,) | ||||
| @@ -861,7 +861,7 @@ class GatherV2(PrimitiveWithCheck): | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | ||||
| class SparseGatherV2(Gather): | |||||
| class SparseGatherV2(PrimitiveWithCheck): | |||||
| """ | """ | ||||
| Returns a slice of input tensor based on the specified indices and axis. | Returns a slice of input tensor based on the specified indices and axis. | ||||
| @@ -893,6 +893,22 @@ class SparseGatherV2(Gather): | |||||
| [2. 55.]] | [2. 55.]] | ||||
| """ | """ | ||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """Initialize index_select""" | |||||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | |||||
| def __check__(self, params, indices, axis): | |||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||||
| validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) | |||||
| axis_v = axis['value'] | |||||
| validator.check_value_type('axis', axis_v, [int], self.name) | |||||
| rank = len(params['shape']) | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||||
| class Padding(PrimitiveWithInfer): | class Padding(PrimitiveWithInfer): | ||||
| """ | """ | ||||