diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index 41d81fef4c..7de00d5160 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -729,62 +729,6 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
   return ret;
 }
 
-AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                    const AbstractBasePtrList &args_spec_list) {
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 2);
-  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(x->shape());
-
-  ShapeVector x_shape = x->shape()->shape();
-  ShapeVector x_shape_min = x->shape()->min_shape();
-  if (x_shape_min.empty()) {
-    x_shape_min = x_shape;
-  }
-  ShapeVector x_shape_max = x->shape()->max_shape();
-  if (x_shape_max.empty()) {
-    x_shape_max = x_shape;
-  }
-
-  int64_t value = 0;
-  if (args_spec_list[1]->isa<AbstractTensor>()) {  // axis is Tensor
-    auto axis = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
-    auto axis_value = axis->BuildValue();
-    if (!axis_value->isa<tensor::Tensor>()) {
-      MS_LOG(EXCEPTION) << axis_value << " axis_value should be tensor, but got " << axis_value->type_name();
-    }
-    auto axis_tensor = axis_value->cast<tensor::TensorPtr>();
-    value = *(static_cast<int64_t *>(axis_tensor->data_c()));
-  } else if (args_spec_list[1]->isa<AbstractScalar>()) {  // axis is Scalar
-    auto axis = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
-    MS_EXCEPTION_IF_NULL(axis);
-    value = GetValue<int64_t>(axis->BuildValue());
-  } else {
-    MS_LOG(EXCEPTION) << "axis incorrect type in ExpandDims";
-  }
-
-  if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
-    MS_LOG(EXCEPTION) << " axis value shoud be in range [-intput_x.dim-1,input_x.dim], but axis value is" << value
-                      << " and input_x.dim is" << x_shape.size();
-  }
-  if (value < 0) {
-    value = value + SizeToInt(x_shape.size()) + 1;
-  }
-  ShapeVector shape;
-  shape.insert(shape.end(), x_shape.begin(), x_shape.end());
-  shape.insert(shape.begin() + value, 1);
-  ShapeVector shape_min;
-  shape_min.insert(shape_min.end(), x_shape_min.begin(), x_shape_min.end());
-  shape_min.insert(shape_min.begin() + value, 1);
-  ShapeVector shape_max;
-  shape_max.insert(shape_max.end(), x_shape_max.begin(), x_shape_max.end());
-  shape_max.insert(shape_max.begin() + value, 1);
-
-  auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
-  return ret;
-}
-
 AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc
index 8726d36b88..46a30b3e6c 100644
--- a/mindspore/core/abstract/prim_others.cc
+++ b/mindspore/core/abstract/prim_others.cc
@@ -492,6 +492,32 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
   return ret;
 }
 
+AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+
+  std::vector<int64_t> shape;
+  std::vector<int64_t> x_shape = x->shape()->shape();
+  shape.insert(shape.end(), x_shape.begin(), x_shape.end());
+  auto axis = primitive->GetAttr("axis");
+  auto value = GetValue<int64_t>(axis);
+  if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
+    MS_LOG(EXCEPTION) << " axis value shoud be in range [-intput_x.dim-1,input_x.dim], but axis value is" << value
+                      << " and input_x.dim is" << x_shape.size();
+  }
+  if (value < 0) {
+    value = value + SizeToInt(x_shape.size()) + 1;
+  }
+  shape.insert(shape.begin() + value, 1);
+
+  auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
+  return ret;
+}
+
 AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                   const AbstractBasePtrList &args_spec_list) {
   const std::string &op_name = primitive->name();
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 9a918aee68..185738c660 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -118,7 +118,7 @@ def _check_infer_attr_reduce(axis, keep_dims, prim_name):
         validator.check_value_type('axis[%d]' % index, value, [int], prim_name)
 
 
-class ExpandDims(PrimitiveWithCheck):
+class ExpandDims(PrimitiveWithInfer):
     """
     Adds an additional dimension at the given axis.
 
@@ -156,13 +156,29 @@ class ExpandDims(PrimitiveWithCheck):
         """Initialize ExpandDims"""
         self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])
 
-    def __check__(self, x, axis):
+    def __infer__(self, x, axis):
         validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
-        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
         x_shape = list(x['shape'])
         axis_v = axis['value']
         rank = len(x_shape)
         validator.check_int_range(axis_v, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
+        value = None
+        if x['value'] is not None:
+            value = x['value'].asnumpy()
+            value = np.expand_dims(value, axis_v)
+            value = Tensor(value)
+        if axis_v < 0:
+            axis_v = rank + 1 + axis_v
+        x_shape.insert(axis_v, 1)
+        out = {'shape': x_shape,
+               'dtype': x['dtype'],
+               'value': value}
+        if 'min_shape' in x and 'max_shape' in x:
+            out['min_shape'] = x['min_shape']
+            out['min_shape'].insert(axis_v, 1)
+            out['max_shape'] = x['max_shape']
+            out['max_shape'].insert(axis_v, 1)
+        return out
 
 
 class DType(PrimitiveWithInfer):