From 39f424fdd75e78d7bce38828900a141a1dd42012 Mon Sep 17 00:00:00 2001 From: LianLiguang Date: Thu, 18 Feb 2021 16:27:50 +0800 Subject: [PATCH] fix bugs of c++ infer --- mindspore/core/abstract/infer_functions.h | 4 + mindspore/core/abstract/prim_nn.cc | 17 +++++ mindspore/core/abstract/prim_others.cc | 16 ++++ .../core/abstract/primitive_infer_map.cc | 2 + mindspore/core/base/core_ops.h | 1 + mindspore/core/ops/apply_momentum.cc | 27 +++---- mindspore/core/ops/assert.cc | 2 +- mindspore/core/ops/bias_add.cc | 42 +++++++++++ mindspore/core/ops/conv2d.cc | 42 ++++++----- mindspore/core/ops/fill.cc | 20 ++++- mindspore/core/ops/gather.cc | 6 +- .../core/ops/grad/conv2d_backprop_filter.cc | 4 - .../core/ops/grad/conv2d_backprop_input.cc | 57 ++++++++++++++ mindspore/core/ops/grad/max_pool_grad.cc | 4 +- mindspore/core/ops/mat_mul.cc | 75 +++++++++++++++++++ mindspore/core/ops/max_pool.cc | 30 +++++--- mindspore/core/ops/merge.cc | 4 +- mindspore/core/ops/shape.cc | 14 ++-- mindspore/core/ops/zeros_like.cc | 1 - mindspore/core/utils/check_convert_utils.cc | 15 ++-- mindspore/core/utils/check_convert_utils.h | 4 +- .../core/utils/tensor_construct_utils.cc | 19 +---- mindspore/core/utils/tensor_construct_utils.h | 10 +++ 23 files changed, 322 insertions(+), 94 deletions(-) diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index ab80f25d7c..92d8670e76 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -308,6 +308,10 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); template AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { // Inputs: a tuple or list or dict. diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index a3c2ea2446..c5c620946a 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -202,6 +202,23 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti return std::make_shared(elements); } +AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(primitive); + auto is_grad = GetValue(primitive->GetAttr("is_grad")); + CheckArgsSize(primitive->name(), args_spec_list, 2); + std::shared_ptr shape = std::make_shared(std::vector{}); + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + if (is_grad) { + shape = args_spec_list[0]->BuildShape(); + } + auto type = args_spec_list[0]->BuildType(); + MS_EXCEPTION_IF_NULL(type); + auto type_tensor = type->cast(); + MS_EXCEPTION_IF_NULL(type_tensor); + return std::make_shared(type_tensor->element(), shape); +} + AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). 
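Note: the shape rule implemented by the new InferImplSparseSoftmaxCrossEntropyWithLogits above is easy to check in isolation: with is_grad false the op produces a scalar loss (empty shape), with is_grad true it returns a gradient shaped like the logits, and the element type is taken from the logits tensor in both cases. Below is a minimal standalone sketch of that rule in plain C++ (no MindSpore headers); the function name SparseSoftmaxCeOutputShape is illustrative only and not part of the patch.

#include <cassert>
#include <cstdint>
#include <vector>

// Standalone mirror of the shape rule used by the new infer function:
// forward mode -> rank-0 loss, grad mode -> same shape as the logits.
std::vector<int64_t> SparseSoftmaxCeOutputShape(const std::vector<int64_t> &logits_shape, bool is_grad) {
  if (is_grad) {
    return logits_shape;  // dlogits keeps the logits shape
  }
  return {};  // empty shape stands for the scalar loss
}

int main() {
  const std::vector<int64_t> logits = {32, 10};  // [batch, num_classes]
  assert(SparseSoftmaxCeOutputShape(logits, false).empty());
  assert(SparseSoftmaxCeOutputShape(logits, true) == logits);
  return 0;
}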
diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 4811e17208..26a9b080cf 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -559,5 +559,21 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con ShapePtr shape = std::make_shared(inferred_shape, min_shape, max_shape); return std::make_shared(input->element(), shape); } + +AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + auto type = args_spec_list[0]->BuildType(); + MS_EXCEPTION_IF_NULL(type); + auto tensor_type = type->cast(); + MS_EXCEPTION_IF_NULL(tensor_type); + auto value = tensor_type->element(); + auto abstract = std::make_shared(value); + abstract->set_value(value); + return abstract; +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index c07d0074eb..6ee6156118 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -193,6 +193,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, {prim::kPrimCast, {InferImplCast, true}}, {prim::kPrimExpandDims, {InferImplExpandDims, true}}, + {prim::kPrimSparseSoftmaxCrossEntropyWithLogits, {InferImplSparseSoftmaxCrossEntropyWithLogits, true}}, + {prim::kPrimDType, {InferImplDType, true}}, }; return prim_eval_implement_map; } diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index d8f6c44a64..3bf3bff11b 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -520,6 +520,7 @@ inline const PrimitivePtr kPrimTopKFusion = std::make_shared("TopKFus inline const PrimitivePtr kPrimTileFusion = std::make_shared("TileFusion"); inline const PrimitivePtr kPrimReduceFusion = std::make_shared("ReduceFusion"); inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared("LayerNormFusion"); +inline const PrimitivePtr kPrimDType = std::make_shared("DType"); class DoSignaturePrimitive : public Primitive { public: diff --git a/mindspore/core/ops/apply_momentum.cc b/mindspore/core/ops/apply_momentum.cc index 5b03ef3230..6be813667d 100644 --- a/mindspore/core/ops/apply_momentum.cc +++ b/mindspore/core/ops/apply_momentum.cc @@ -56,32 +56,29 @@ float ApplyMomentum::get_gradient_scale() const { AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto momentum_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(momentum_prim); - auto prim_name = momentum_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); // Infer shape auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name); // Infer type - auto v_type = input_args[0]->BuildType()->cast()->element(); - auto a_type = input_args[1]->BuildType()->cast()->element(); - auto l_type = input_args[2]->BuildType()->cast()->element(); - auto g_type = input_args[3]->BuildType()->cast()->element(); - auto m_type = input_args[4]->BuildType()->cast()->element(); + 
auto v_tensor_type = input_args[0]->BuildType(); + auto a_tensor_type = input_args[1]->BuildType(); + auto l_type = input_args[2]->BuildType(); + auto g_type = input_args[3]->BuildType(); + auto m_type = input_args[4]->BuildType(); const std::set valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; - CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, valid_types, prim_name); - CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_type, valid_types, prim_name); - const std::set valid_types_ptr = {TypeIdToType(kNumberTypeFloat16), TypeIdToType(kNumberTypeFloat32), - TypeIdToType(kNumberTypeFloat64)}; + CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name); + CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name); std::map args; args.insert({"l_type", l_type}); args.insert({"g_type", g_type}); args.insert({"m_type", m_type}); - CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types_ptr, prim_name); - - return std::make_shared(g_type, v_shape); + CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types, prim_name); + auto g_type_tensor = g_type->cast(); + auto element = g_type_tensor->element(); + return std::make_shared(element, v_shape); } REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer); REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum); diff --git a/mindspore/core/ops/assert.cc b/mindspore/core/ops/assert.cc index e0bb23d1dc..9393a9fe0b 100644 --- a/mindspore/core/ops/assert.cc +++ b/mindspore/core/ops/assert.cc @@ -61,7 +61,7 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive condition = input_args[0]->BuildType(); } std::vector output_shape = {1}; - std::set local_bool = {TypeIdToType(kNumberTypeBool)}; + std::set local_bool = {kNumberTypeBool}; std::map args = {{"condition", condition}}; CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name); auto inputs_type = input_args[1]->BuildType()->cast()->elements(); diff --git a/mindspore/core/ops/bias_add.cc b/mindspore/core/ops/bias_add.cc index 6b3331eab6..b1009bcc40 100644 --- a/mindspore/core/ops/bias_add.cc +++ b/mindspore/core/ops/bias_add.cc @@ -23,6 +23,42 @@ namespace mindspore { namespace ops { +// Add +namespace { +abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + // check + CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name); + CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); + CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name); + auto format = Format(GetValue(primitive->GetAttr(kFormat))); + auto x_channel = x_shape[1]; + if (format != NCHW) { + x_channel = x_shape[x_shape.size() - 1]; + } + CheckAndConvertUtils::Check("b_shape[0]", b_shape[0], kEqual, "x_shape[1]", x_channel, prim_name); + + std::vector out_shape = x_shape; + return std::make_shared(out_shape); +} + +TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + 
CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + std::map types; + types.emplace("input_x", input_args[0]->BuildType()); + types.emplace("bias", input_args[1]->BuildType()); + auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return TypeIdToType(infer_type); +} +} // namespace void BiasAdd::set_format(const Format &format) { int64_t f = format; this->AddAttr(kFormat, MakeValue(f)); @@ -32,7 +68,13 @@ Format BiasAdd::get_format() const { return Format(GetValue(value_ptr)); } void BiasAdd::Init(const Format &format) { this->set_format(format); } +AbstractBasePtr BiasAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + return std::make_shared(BiasAddInferType(primitive, input_args), + BiasAddInferShape(primitive, input_args)); +} // Add +REGISTER_PRIMITIVE_EVAL_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddInfer); REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/conv2d.cc b/mindspore/core/ops/conv2d.cc index 98bf0ed2c4..8d88d2e189 100644 --- a/mindspore/core/ops/conv2d.cc +++ b/mindspore/core/ops/conv2d.cc @@ -30,32 +30,31 @@ namespace ops { namespace { abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto conv_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(conv_prim); - auto prim_name = conv_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); - if (conv_prim->get_format() == NHWC) { + auto format = Format(GetValue(primitive->GetAttr(kFormat))); + if (format == NHWC) { x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; } CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); - CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]", - w_shape[1], conv_prim->name()); - auto out_channel = conv_prim->get_out_channel(); - CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); + CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue(primitive->GetAttr(kGroup)), kEqual, + "w_shape[1]", w_shape[1], prim_name); + auto out_channel = GetValue(primitive->GetAttr(kOutChannel)); + CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name); std::vector temp_w; std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); - CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w, - conv_prim->name()); + CheckAndConvertUtils::Check("kernel_size", GetValue>(primitive->GetAttr(kKernelSize)), kEqual, + "w_shape[2:4]", temp_w, prim_name); auto kernel_size_h = w_shape[2]; auto kernel_size_w = w_shape[3]; - auto stride = conv_prim->get_stride(); - auto dilation = conv_prim->get_dilation(); + auto stride 
= GetValue>(primitive->GetAttr(kStride)); + auto dilation = GetValue>(primitive->GetAttr(kDilation)); auto stride_h = stride[2]; auto stride_w = stride[3]; auto dilation_h = dilation[2]; @@ -63,7 +62,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve int64_t h_out = -1; int64_t w_out = -1; std::vector pad_list(4, 0); - auto pad_mode = conv_prim->get_pad_mode(); + auto pad_mode = PadMode(GetValue(primitive->GetAttr(kPadMode))); if (pad_mode == VALID) { h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); @@ -81,20 +80,23 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve pad_list.emplace_back(pad_left); pad_list.emplace_back(pad_needed_h - pad_left); } else if (pad_mode == PAD) { - std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list)); - auto pad_top = conv_prim->get_pad()[0]; - auto pad_bottom = conv_prim->get_pad()[1]; - auto pad_right = conv_prim->get_pad()[2]; - auto pad_left = conv_prim->get_pad()[3]; + auto pad = GetValue>(primitive->GetAttr(kPad)); + std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list)); + auto pad_top = pad[0]; + auto pad_bottom = pad[1]; + auto pad_right = pad[2]; + auto pad_left = pad[3]; h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; h_out = floor(h_out); w_out = floor(w_out); } - conv_prim->set_pad(pad_list); + CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name); + primitive->AddAttr(kPadList, + MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true))); std::vector out_shape = {x_shape[0], out_channel, h_out, w_out}; - if (conv_prim->get_format() == NHWC) { + if (format == NHWC) { out_shape = {x_shape[0], h_out, w_out, out_channel}; } diff --git a/mindspore/core/ops/fill.cc b/mindspore/core/ops/fill.cc index f90781c951..0efddf5335 100644 --- a/mindspore/core/ops/fill.cc +++ b/mindspore/core/ops/fill.cc @@ -15,8 +15,10 @@ */ #include "ops/fill.h" +#include #include "ops/op_utils.h" #include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" namespace mindspore { namespace ops { @@ -38,7 +40,23 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt valid_types.insert(kNumberTypeBool); CheckAndConvertUtils::CheckTypeSame("output datatype", dtype, valid_types, prim_name); auto out_shape = GetValue>(input_args[1]->BuildValue()); - return std::make_shared(dtype, std::make_shared(out_shape)); + auto x_type = input_args[2]->BuildType(); + auto x_type_id = x_type->type_id(); + auto x_value = input_args[2]->BuildValue(); + auto abs = std::make_shared(dtype, std::make_shared(out_shape)); + tensor::TensorPtr tensor = std::make_shared(x_type_id, out_shape); + auto mem_size = IntToSize(tensor->ElementsNum()); + if (x_type_id == kNumberTypeInt) { + auto num = GetValue(x_value); + SetTensorData(tensor->data_c(), num, mem_size); + } else if (x_type_id == kNumberTypeFloat || x_type_id == kNumberTypeFloat32) { + auto num = GetValue(x_value); + SetTensorData(tensor->data_c(), num, mem_size); + } else { + MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[2]->ToString(); + } + abs->set_value(tensor); + return abs; } 
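Note: FillInfer above now folds a constant Fill at infer time — the output abstract carries a value tensor whose element count comes from the requested shape and whose elements are all set to the scalar operand. A minimal standalone sketch of that folding step, assuming nothing beyond standard C++ (FoldFill and the vector buffer stand in for the MindSpore tensor allocation; they are not names from the patch):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the constant fold done in FillInfer:
// materialize shape-many copies of the fill value up front.
template <typename T>
std::vector<T> FoldFill(const std::vector<int64_t> &shape, T value) {
  const int64_t elem_count =
      std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
  return std::vector<T>(static_cast<size_t>(elem_count), value);  // plays the role of SetTensorData
}

int main() {
  auto folded = FoldFill<float>({2, 3}, 1.5f);  // e.g. Fill(float32, (2, 3), 1.5)
  std::cout << folded.size() << " elements, all " << folded.front() << "\n";  // 6 elements, all 1.5
  return 0;
}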
REGISTER_PRIMITIVE_EVAL_IMPL(Fill, prim::kPrimFill, FillInfer); REGISTER_PRIMITIVE_C(kNameFill, Fill); diff --git a/mindspore/core/ops/gather.cc b/mindspore/core/ops/gather.cc index 55fcc1b5af..6c18c5da34 100644 --- a/mindspore/core/ops/gather.cc +++ b/mindspore/core/ops/gather.cc @@ -23,15 +23,11 @@ namespace ops { AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto gather_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(gather_prim); - auto prim_name = gather_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name); // Infer type auto x_type = input_args[0]->BuildType()->cast()->element(); - // auto dim_type = input_args[1]->BuildType(); - // auto index_type = input_args[2]->BuildType()->cast()->element(); std::set valid_x_type = {TypeIdToType(kObjectTypeTensorType)}; CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); const std::set valid_index_types = {kNumberTypeInt32, kNumberTypeInt64}; diff --git a/mindspore/core/ops/grad/conv2d_backprop_filter.cc b/mindspore/core/ops/grad/conv2d_backprop_filter.cc index 7c51ade8b4..3d06efe7e9 100644 --- a/mindspore/core/ops/grad/conv2d_backprop_filter.cc +++ b/mindspore/core/ops/grad/conv2d_backprop_filter.cc @@ -27,10 +27,6 @@ namespace { abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto conv2d_backprop_filter_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(conv2d_backprop_filter_prim); - // auto prim_name = conv2d_backprop_filter_prim->name(); - auto out_put = input_args[2]->BuildValue(); auto infer_shape = GetValue>(out_put); return std::make_shared(infer_shape); diff --git a/mindspore/core/ops/grad/conv2d_backprop_input.cc b/mindspore/core/ops/grad/conv2d_backprop_input.cc index 6c2eeee8f4..344ffe949c 100644 --- a/mindspore/core/ops/grad/conv2d_backprop_input.cc +++ b/mindspore/core/ops/grad/conv2d_backprop_input.cc @@ -23,6 +23,62 @@ namespace mindspore { namespace ops { +AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name); + for (auto item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto doutput = input_args[0]; + auto x_size = input_args[2]; + auto x_size_value = x_size->GetValueTrack(); + MS_EXCEPTION_IF_NULL(x_size); + auto x_size_v = GetValue>(x_size_value); + // infer dtype + auto dtype = doutput->BuildType(); + if (!dtype->isa()) { + MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString(); + } + auto input_tensor_type = dtype->cast(); + MS_EXCEPTION_IF_NULL(input_tensor_type); + auto element = input_tensor_type->element(); + // infer shape + auto dout_shape = doutput->BuildShape(); + MS_EXCEPTION_IF_NULL(doutput); + auto dout_shapeptr = dout_shape->cast(); + auto dout_shape_norm = dout_shapeptr->shape(); + auto kernel_size = GetValue>(primitive->GetAttr(kKernelSize)); + auto stride = GetValue>(primitive->GetAttr(kStride)); + auto dilation = GetValue>(primitive->GetAttr(kStride)); + auto pad_list = GetValue>(primitive->GetAttr(kPadList)); + auto pad_mode = 
PadMode(GetValue(primitive->GetAttr(kPadMode))); + if (std::all_of(pad_list.begin(), pad_list.end(), [](int64_t elem) -> bool { return elem != 0; })) { + primitive->AddAttr(kPadList, MakeValue(pad_list)); + } else if (pad_mode == SAME) { + auto stride_h = stride[0]; + auto stride_w = stride[1]; + auto kernel_h = kernel_size[0]; + auto kernel_w = kernel_size[1]; + auto dilation_h = dilation[2]; + auto dilation_w = dilation[3]; + auto pad_needed_h = (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2]; + pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h; + auto pad_top = pad_needed_h / 2; + auto pad_bottom = pad_needed_h - pad_top; + auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[2]; + pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L; + auto pad_left = pad_needed_w / 2; + auto pad_right = pad_needed_w - pad_left; + pad_list = {pad_top, pad_bottom, pad_left, pad_right}; + } else if (pad_mode == PAD) { + pad_list = GetValue>(primitive->GetAttr(kPad)); + } + primitive->AddAttr(kPadList, MakeValue(pad_list)); + return std::make_shared(element, std::make_shared(x_size_v)); +} + void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector &kernel_size, int64_t mode, const PadMode &pad_mode, const std::vector &pad, const std::vector &stride, const std::vector &dilation, int64_t group, @@ -140,6 +196,7 @@ std::vector Conv2DBackpropInput::get_pad_list() const { auto value_ptr = GetAttr(kPadList); return GetValue>(value_ptr); } +REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer); REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/grad/max_pool_grad.cc b/mindspore/core/ops/grad/max_pool_grad.cc index 14b2667f12..8080687318 100644 --- a/mindspore/core/ops/grad/max_pool_grad.cc +++ b/mindspore/core/ops/grad/max_pool_grad.cc @@ -52,9 +52,7 @@ void MaxPoolGrad::set_strides(const std::vector &strides) { AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - auto MaxPoolGrad_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(MaxPoolGrad_prim); - auto op_name = MaxPoolGrad_prim->name(); + auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); auto tensor_type = input_args[0]->BuildType()->cast(); diff --git a/mindspore/core/ops/mat_mul.cc b/mindspore/core/ops/mat_mul.cc index 6d6c77e3e2..678d8e3e2e 100644 --- a/mindspore/core/ops/mat_mul.cc +++ b/mindspore/core/ops/mat_mul.cc @@ -21,6 +21,81 @@ namespace mindspore { namespace ops { +namespace { +abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); + auto trans_a = GetValue(primitive->GetAttr(kTransposeA)); + auto trans_b = GetValue(primitive->GetAttr(kTransposeB)); + + auto out_n = x_shape[0]; + auto out_m = w_shape[1]; + auto x_C = x_shape[1]; + auto w_C = 
w_shape[0]; + + if (trans_a) { + out_n = x_shape[1]; + x_C = x_shape[0]; + } + if (trans_b) { + out_m = w_shape[0]; + w_C = w_shape[1]; + } + CheckAndConvertUtils::CheckInteger("dim C is not equal", x_C, kEqual, w_C, prim_name); + primitive->AddAttr("transpose_x1", MakeValue(trans_a)); + primitive->AddAttr("transpose_x2", MakeValue(trans_b)); + std::vector out_shape = {out_n, out_m}; + return std::make_shared(out_shape); +} + +TypePtr MatMulInferType(const PrimitivePtr &prim, const std::vector &input_args) { + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + const std::set valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, + kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; + std::map types; + types.emplace("x", input_args[0]->BuildType()); + types.emplace("w", input_args[1]->BuildType()); + auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); + if (infer_type == kNumberTypeInt8) { + return std::make_shared(TypeIdToType(kNumberTypeInt32)); + } + return TypeIdToType(infer_type); +} +} // namespace + +void MatMul::Init(bool transpose_a, bool transpose_b) { + set_transpose_a(transpose_a); + set_transpose_b(transpose_b); +} + +void MatMul::set_transpose_a(bool transpose_a) { AddAttr(kTransposeA, MakeValue(transpose_a)); } + +void MatMul::set_transpose_b(bool transpose_b) { AddAttr(kTransposeB, MakeValue(transpose_b)); } + +bool MatMul::get_transpose_a() const { + auto value_ptr = GetAttr(kTransposeA); + return GetValue(value_ptr); +} + +bool MatMul::get_transpose_b() const { + auto value_ptr = GetAttr(kTransposeB); + return GetValue(value_ptr); +} + +// Add +AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + return std::make_shared(MatMulInferType(primitive, input_args), + MatMulInferShape(primitive, input_args)->shape()); +} + +// Add +REGISTER_PRIMITIVE_EVAL_IMPL(MatMul, prim::kPrimMatMul, MatMulInfer); REGISTER_PRIMITIVE_C(kNameMatMul, MatMul); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/max_pool.cc b/mindspore/core/ops/max_pool.cc index b43f50c8a4..516b287556 100644 --- a/mindspore/core/ops/max_pool.cc +++ b/mindspore/core/ops/max_pool.cc @@ -94,22 +94,22 @@ void MaxPool::Init(const std::vector &kernel_size, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto pool_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(pool_prim); - auto op_name = pool_prim->name(); + auto op_name = primitive->name(); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); - if (pool_prim->get_format() == NHWC) { + auto format = Format(GetValue(primitive->GetAttr(kFormat))); + if (format == NHWC) { in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; } CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); - auto kernel_size = pool_prim->get_kernel_size(); - auto pad_mode = pool_prim->get_pad_mode(); + auto kernel_size = GetValue>(primitive->GetAttr(kKernelSize)); + auto pad_mode_value = (primitive->GetAttr(kPadMode)); + PadMode pad_mode = PAD; + pad_mode = PadMode(GetValue(pad_mode_value)); auto batch = in_shape[0]; auto channel = in_shape[1]; auto in_h = in_shape[2]; auto in_w = in_shape[3]; - - auto strides = pool_prim->get_strides(); + auto strides = GetValue>(primitive->GetAttr(kStrides)); auto kernel_h = kernel_size[2]; auto kernel_w = kernel_size[3]; auto stride_h = 
strides[2]; @@ -117,14 +117,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector out_shape = {batch, channel, out_h, out_w}; - if (pool_prim->get_format() == NHWC) { + if (format == NHWC) { out_shape = {batch, out_h, out_w, channel}; } if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { @@ -137,7 +137,13 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { MS_LOG(EXCEPTION) << "nullptr"; } - return input_args[0]->BuildType(); + auto input_type = input_args[0]->BuildType(); + MS_EXCEPTION_IF_NULL(input_type); + auto input_tensor_type = input_type->cast(); + if (input_tensor_type == nullptr) { + MS_LOG_EXCEPTION << "The maxpool's input must be a tensor but got " << input_type->ToString(); + } + return input_tensor_type->element(); } } // namespace diff --git a/mindspore/core/ops/merge.cc b/mindspore/core/ops/merge.cc index 6378a2a56f..76176464ca 100644 --- a/mindspore/core/ops/merge.cc +++ b/mindspore/core/ops/merge.cc @@ -38,9 +38,9 @@ AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP for (int64_t i = 0; i != (int64_t)inputs_type.size(); i++) { args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]}); } - std::set template_type = {TypeIdToType(kNumberTypeBool)}; + std::set template_type = {kNumberTypeBool}; for (auto item : common_valid_types) { - template_type.insert(TypeIdToType(item)); + template_type.insert(item); } CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name); std::vector in_shape0 = inputs_shape[0]->cast()->shape(); diff --git a/mindspore/core/ops/shape.cc b/mindspore/core/ops/shape.cc index 507c6e3179..32f02b7af7 100644 --- a/mindspore/core/ops/shape.cc +++ b/mindspore/core/ops/shape.cc @@ -30,13 +30,17 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP const std::vector &input_args) { // infer shape MS_EXCEPTION_IF_NULL(primitive); - auto shape_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(shape_prim); - auto op_name = shape_prim->name(); + auto op_name = primitive->name(); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); // infer type - auto x_type = input_args[0]->BuildType()->cast()->element(); - return std::make_shared(x_type, in_shape); + AbstractBasePtrList abs_list; + std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list), + [](int64_t item) -> std::shared_ptr { + return std::make_shared(item); + }); + auto abs = std::make_shared(abs_list); + abs->set_value(MakeValue(in_shape)); + return abs; } REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer); REGISTER_PRIMITIVE_C(kNameShape, Shape); diff --git a/mindspore/core/ops/zeros_like.cc b/mindspore/core/ops/zeros_like.cc index 5bfe3e106f..9a38ecc988 100644 --- a/mindspore/core/ops/zeros_like.cc +++ b/mindspore/core/ops/zeros_like.cc @@ -60,7 +60,6 @@ AbstractBasePtr ZerosLikeInfer(const abstract::AnalysisEnginePtr &, const Primit return std::make_shared(InferType(primitive, input_args), InferShape(primitive, input_args)->shape()); } -REGISTER_PRIMITIVE_EVAL_IMPL(ZerosLike, prim::kPrimZerosLike, ZerosLikeInfer); REGISTER_PRIMITIVE_C(kNameZerosLike, ZerosLike); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index bdc22af34a..d09d4dbf9c 100644 
--- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -486,7 +486,7 @@ void CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const Typ } void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map &args, - const std::set &valid_values, + const std::set &valid_values, const std::string &prim_name, const bool allow_mix) { std::vector> check_results; for (auto &iter : args) { @@ -502,7 +502,7 @@ void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map CheckAndConvertUtils::_CheckArgumentType(const std::map &arg, - const std::set &valid_values, + const std::set &valid_values, const std::string &prim_name) { std::string arg_key = arg.begin()->first; TypePtr arg_val = arg.begin()->second; @@ -512,15 +512,16 @@ std::map CheckAndConvertUtils::_CheckArgumentType(const st arg_val = arg_val_->element(); } - auto it = valid_values.find(arg_val); + auto it = valid_values.find(arg_val->type_id()); if (it == valid_values.end()) { std::ostringstream buffer; buffer << "For '" << prim_name << "' , the `" << arg_key << "` should be in { "; for (auto valid_value : valid_values) { - buffer << valid_value->ToString() << " },"; - buffer << "but `" << arg_key << "`" - << "is" << arg_val->ToString() << "."; + buffer << TypeIdToType(valid_value)->ToString() << ","; } + buffer << " },"; + buffer << "but `" << arg_key << "`" + << "is" << arg_val->ToString() << "."; MS_EXCEPTION(TypeError) << buffer.str(); } return arg; @@ -546,7 +547,7 @@ std::map CheckAndConvertUtils::_CheckTypeSame(const std::m except_flag = true; } - if (except_flag || arg1_type != arg2_type) { + if (except_flag || arg1_type->type_id() != arg2_type->type_id()) { std::ostringstream buffer; buffer << "For '" << prim_name << "'" << "type of " diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index 677fdc1cae..fc17e2aaa0 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -277,7 +277,7 @@ class CheckAndConvertUtils { static void CheckSubClass(const std::string &type_name, const TypePtr type, const std::set &template_types, const std::string &prim_name); static void CheckScalarOrTensorTypesSame(const std::map &args, - const std::set &valid_values, const std::string &prim_name, + const std::set &valid_values, const std::string &prim_name, bool allow_mix = false); static TypeId CheckTypeSame(const std::string &arg_name, const TypePtr arg_type, const std::set &valid_type, const std::string &prim_name); @@ -291,7 +291,7 @@ class CheckAndConvertUtils { private: static bool IsEqualVector(const std::vector &vec_1, const std::vector &vec_2); static std::map _CheckArgumentType(const std::map &arg, - const std::set &valid_values, + const std::set &valid_values, const std::string &prim_name); static std::map _CheckTypeSame(const std::map &arg1, const std::map &arg2, diff --git a/mindspore/core/utils/tensor_construct_utils.cc b/mindspore/core/utils/tensor_construct_utils.cc index 6382ec0fe1..21c93d9d79 100644 --- a/mindspore/core/utils/tensor_construct_utils.cc +++ b/mindspore/core/utils/tensor_construct_utils.cc @@ -17,21 +17,8 @@ #include #include namespace mindspore { -namespace { -template -void SetTensorData(void *data, float num, size_t data_length) { - MS_EXCEPTION_IF_NULL(data); - auto tensor_data = reinterpret_cast(data); - MS_EXCEPTION_IF_NULL(tensor_data); - for (size_t index = 0; index < data_length; ++index) { - *tensor_data = num; - ++tensor_data; - } -} 
-} // namespace tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector &shape) { tensor::TensorPtr tensor = std::make_shared(type, shape); - size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); auto tensor_data = tensor->data_c(); char *data = reinterpret_cast(tensor_data); @@ -43,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector &shape) { tensor::TensorPtr tensor = std::make_shared(type, shape); - auto mem_size = IntToSize(tensor->ElementsNum()); + size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); if (tensor->data_type() == kNumberTypeFloat32) { - SetTensorData(tensor->data_c(), 1.0, mem_size); + SetTensorData(tensor->data_c(), 1.0, mem_size); } else if (tensor->data_type() == kNumberTypeInt) { - SetTensorData(tensor->data_c(), 1, mem_size); + SetTensorData(tensor->data_c(), 1, mem_size); } return tensor; } diff --git a/mindspore/core/utils/tensor_construct_utils.h b/mindspore/core/utils/tensor_construct_utils.h index 4094c67ace..4bb87ab27b 100644 --- a/mindspore/core/utils/tensor_construct_utils.h +++ b/mindspore/core/utils/tensor_construct_utils.h @@ -18,6 +18,16 @@ #include #include "ir/tensor.h" namespace mindspore { +template +void SetTensorData(void *data, T num, size_t data_length) { + MS_EXCEPTION_IF_NULL(data); + auto tensor_data = reinterpret_cast(data); + MS_EXCEPTION_IF_NULL(tensor_data); + for (size_t index = 0; index < data_length; ++index) { + *tensor_data = num; + ++tensor_data; + } +} class TensorConstructUtils { public: static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector &shape);
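Note: the SetTensorData template relocated into tensor_construct_utils.h writes data_length elements of type T starting at data. A self-contained sketch of the same fill pattern in standard C++ (the vector buffer below is illustrative and stands in for tensor->data_c(); it is not the MindSpore allocator):

#include <cassert>
#include <cstddef>
#include <vector>

// Same fill pattern as the header's SetTensorData: write data_length
// copies of num into a raw buffer interpreted as T.
template <typename T>
void SetTensorData(void *data, T num, size_t data_length) {
  auto *typed = static_cast<T *>(data);
  for (size_t index = 0; index < data_length; ++index) {
    typed[index] = num;
  }
}

int main() {
  std::vector<float> buffer(6, 0.0f);                        // stands in for tensor->data_c()
  SetTensorData<float>(buffer.data(), 1.0f, buffer.size());  // ones fill, as in CreateOnesTensor
  for (float v : buffer) {
    assert(v == 1.0f);
  }
  return 0;
}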