From: @lianliguang Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -308,6 +308,10 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| template <typename T> | template <typename T> | ||||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: a tuple or list or dict. | // Inputs: a tuple or list or dict. | ||||
| @@ -202,6 +202,23 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti | |||||
| return std::make_shared<AbstractTuple>(elements); | return std::make_shared<AbstractTuple>(elements); | ||||
| } | } | ||||
| AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto is_grad = GetValue<bool>(primitive->GetAttr("is_grad")); | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||||
| std::shared_ptr<BaseShape> shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{}); | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||||
| if (is_grad) { | |||||
| shape = args_spec_list[0]->BuildShape(); | |||||
| } | |||||
| auto type = args_spec_list[0]->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| auto type_tensor = type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(type_tensor); | |||||
| return std::make_shared<abstract::AbstractTensor>(type_tensor->element(), shape); | |||||
| } | |||||
| AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). | // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). | ||||
| @@ -559,5 +559,21 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con | |||||
| ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape); | ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape); | ||||
| return std::make_shared<AbstractTensor>(input->element(), shape); | return std::make_shared<AbstractTensor>(input->element(), shape); | ||||
| } | } | ||||
| AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||||
| auto type = args_spec_list[0]->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| auto tensor_type = type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| auto value = tensor_type->element(); | |||||
| auto abstract = std::make_shared<abstract::AbstractType>(value); | |||||
| abstract->set_value(value); | |||||
| return abstract; | |||||
| } | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -193,6 +193,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | ||||
| {prim::kPrimCast, {InferImplCast, true}}, | {prim::kPrimCast, {InferImplCast, true}}, | ||||
| {prim::kPrimExpandDims, {InferImplExpandDims, true}}, | {prim::kPrimExpandDims, {InferImplExpandDims, true}}, | ||||
| {prim::kPrimSparseSoftmaxCrossEntropyWithLogits, {InferImplSparseSoftmaxCrossEntropyWithLogits, true}}, | |||||
| {prim::kPrimDType, {InferImplDType, true}}, | |||||
| }; | }; | ||||
| return prim_eval_implement_map; | return prim_eval_implement_map; | ||||
| } | } | ||||
| @@ -521,6 +521,7 @@ inline const PrimitivePtr kPrimTopKFusion = std::make_shared<Primitive>("TopKFus | |||||
| inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion"); | inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion"); | ||||
| inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion"); | inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion"); | ||||
| inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion"); | inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion"); | ||||
| inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType"); | |||||
| class DoSignaturePrimitive : public Primitive { | class DoSignaturePrimitive : public Primitive { | ||||
| public: | public: | ||||
| @@ -56,32 +56,29 @@ float ApplyMomentum::get_gradient_scale() const { | |||||
| AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto momentum_prim = primitive->cast<PrimApplyMomentumPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(momentum_prim); | |||||
| auto prim_name = momentum_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); | CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name); | auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name); | ||||
| // Infer type | // Infer type | ||||
| auto v_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto a_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto l_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto g_type = input_args[3]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto m_type = input_args[4]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto v_tensor_type = input_args[0]->BuildType(); | |||||
| auto a_tensor_type = input_args[1]->BuildType(); | |||||
| auto l_type = input_args[2]->BuildType(); | |||||
| auto g_type = input_args[3]->BuildType(); | |||||
| auto m_type = input_args[4]->BuildType(); | |||||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; | const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; | ||||
| CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, valid_types, prim_name); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_type, valid_types, prim_name); | |||||
| const std::set<TypePtr> valid_types_ptr = {TypeIdToType(kNumberTypeFloat16), TypeIdToType(kNumberTypeFloat32), | |||||
| TypeIdToType(kNumberTypeFloat64)}; | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name); | |||||
| std::map<std::string, TypePtr> args; | std::map<std::string, TypePtr> args; | ||||
| args.insert({"l_type", l_type}); | args.insert({"l_type", l_type}); | ||||
| args.insert({"g_type", g_type}); | args.insert({"g_type", g_type}); | ||||
| args.insert({"m_type", m_type}); | args.insert({"m_type", m_type}); | ||||
| CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types_ptr, prim_name); | |||||
| return std::make_shared<abstract::AbstractTensor>(g_type, v_shape); | |||||
| CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types, prim_name); | |||||
| auto g_type_tensor = g_type->cast<TensorTypePtr>(); | |||||
| auto element = g_type_tensor->element(); | |||||
| return std::make_shared<abstract::AbstractTensor>(element, v_shape); | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer); | ||||
| REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum); | REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum); | ||||
| @@ -61,7 +61,7 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive | |||||
| condition = input_args[0]->BuildType(); | condition = input_args[0]->BuildType(); | ||||
| } | } | ||||
| std::vector<int64_t> output_shape = {1}; | std::vector<int64_t> output_shape = {1}; | ||||
| std::set<TypePtr> local_bool = {TypeIdToType(kNumberTypeBool)}; | |||||
| std::set<TypeId> local_bool = {kNumberTypeBool}; | |||||
| std::map<std::string, TypePtr> args = {{"condition", condition}}; | std::map<std::string, TypePtr> args = {{"condition", condition}}; | ||||
| CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name); | CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name); | ||||
| auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements(); | auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements(); | ||||
| @@ -23,6 +23,42 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| // Add | |||||
| namespace { | |||||
| abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| // check | |||||
| CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name); | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| auto x_channel = x_shape[1]; | |||||
| if (format != NCHW) { | |||||
| x_channel = x_shape[x_shape.size() - 1]; | |||||
| } | |||||
| CheckAndConvertUtils::Check("b_shape[0]", b_shape[0], kEqual, "x_shape[1]", x_channel, prim_name); | |||||
| std::vector<int64_t> out_shape = x_shape; | |||||
| return std::make_shared<abstract::Shape>(out_shape); | |||||
| } | |||||
| TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("input_x", input_args[0]->BuildType()); | |||||
| types.emplace("bias", input_args[1]->BuildType()); | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||||
| return TypeIdToType(infer_type); | |||||
| } | |||||
| } // namespace | |||||
| void BiasAdd::set_format(const Format &format) { | void BiasAdd::set_format(const Format &format) { | ||||
| int64_t f = format; | int64_t f = format; | ||||
| this->AddAttr(kFormat, MakeValue(f)); | this->AddAttr(kFormat, MakeValue(f)); | ||||
| @@ -32,7 +68,13 @@ Format BiasAdd::get_format() const { | |||||
| return Format(GetValue<int64_t>(value_ptr)); | return Format(GetValue<int64_t>(value_ptr)); | ||||
| } | } | ||||
| void BiasAdd::Init(const Format &format) { this->set_format(format); } | void BiasAdd::Init(const Format &format) { this->set_format(format); } | ||||
| AbstractBasePtr BiasAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(BiasAddInferType(primitive, input_args), | |||||
| BiasAddInferShape(primitive, input_args)); | |||||
| } | |||||
| // Add | // Add | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd); | REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,32 +30,31 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto conv_prim = primitive->cast<PrimConv2dPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(conv_prim); | |||||
| auto prim_name = conv_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | ||||
| if (conv_prim->get_format() == NHWC) { | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | ||||
| w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | ||||
| } | } | ||||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]", | |||||
| w_shape[1], conv_prim->name()); | |||||
| auto out_channel = conv_prim->get_out_channel(); | |||||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); | |||||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual, | |||||
| "w_shape[1]", w_shape[1], prim_name); | |||||
| auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel)); | |||||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name); | |||||
| std::vector<int64_t> temp_w; | std::vector<int64_t> temp_w; | ||||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | ||||
| CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w, | |||||
| conv_prim->name()); | |||||
| CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual, | |||||
| "w_shape[2:4]", temp_w, prim_name); | |||||
| auto kernel_size_h = w_shape[2]; | auto kernel_size_h = w_shape[2]; | ||||
| auto kernel_size_w = w_shape[3]; | auto kernel_size_w = w_shape[3]; | ||||
| auto stride = conv_prim->get_stride(); | |||||
| auto dilation = conv_prim->get_dilation(); | |||||
| auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | |||||
| auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation)); | |||||
| auto stride_h = stride[2]; | auto stride_h = stride[2]; | ||||
| auto stride_w = stride[3]; | auto stride_w = stride[3]; | ||||
| auto dilation_h = dilation[2]; | auto dilation_h = dilation[2]; | ||||
| @@ -63,7 +62,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| int64_t h_out = -1; | int64_t h_out = -1; | ||||
| int64_t w_out = -1; | int64_t w_out = -1; | ||||
| std::vector<int64_t> pad_list(4, 0); | std::vector<int64_t> pad_list(4, 0); | ||||
| auto pad_mode = conv_prim->get_pad_mode(); | |||||
| auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode))); | |||||
| if (pad_mode == VALID) { | if (pad_mode == VALID) { | ||||
| h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | ||||
| w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | ||||
| @@ -81,20 +80,23 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| pad_list.emplace_back(pad_left); | pad_list.emplace_back(pad_left); | ||||
| pad_list.emplace_back(pad_needed_h - pad_left); | pad_list.emplace_back(pad_needed_h - pad_left); | ||||
| } else if (pad_mode == PAD) { | } else if (pad_mode == PAD) { | ||||
| std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list)); | |||||
| auto pad_top = conv_prim->get_pad()[0]; | |||||
| auto pad_bottom = conv_prim->get_pad()[1]; | |||||
| auto pad_right = conv_prim->get_pad()[2]; | |||||
| auto pad_left = conv_prim->get_pad()[3]; | |||||
| auto pad = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad)); | |||||
| std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list)); | |||||
| auto pad_top = pad[0]; | |||||
| auto pad_bottom = pad[1]; | |||||
| auto pad_right = pad[2]; | |||||
| auto pad_left = pad[3]; | |||||
| h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; | h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; | ||||
| w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; | w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; | ||||
| h_out = floor(h_out); | h_out = floor(h_out); | ||||
| w_out = floor(w_out); | w_out = floor(w_out); | ||||
| } | } | ||||
| conv_prim->set_pad(pad_list); | |||||
| CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name); | |||||
| primitive->AddAttr(kPadList, | |||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true))); | |||||
| std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | ||||
| if (conv_prim->get_format() == NHWC) { | |||||
| if (format == NHWC) { | |||||
| out_shape = {x_shape[0], h_out, w_out, out_channel}; | out_shape = {x_shape[0], h_out, w_out, out_channel}; | ||||
| } | } | ||||
| @@ -15,8 +15,10 @@ | |||||
| */ | */ | ||||
| #include "ops/fill.h" | #include "ops/fill.h" | ||||
| #include <memory> | |||||
| #include "ops/op_utils.h" | #include "ops/op_utils.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "utils/tensor_construct_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| @@ -38,7 +40,23 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt | |||||
| valid_types.insert(kNumberTypeBool); | valid_types.insert(kNumberTypeBool); | ||||
| CheckAndConvertUtils::CheckTypeSame("output datatype", dtype, valid_types, prim_name); | CheckAndConvertUtils::CheckTypeSame("output datatype", dtype, valid_types, prim_name); | ||||
| auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue()); | auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue()); | ||||
| return std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape)); | |||||
| auto x_type = input_args[2]->BuildType(); | |||||
| auto x_type_id = x_type->type_id(); | |||||
| auto x_value = input_args[2]->BuildValue(); | |||||
| auto abs = std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape)); | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(x_type_id, out_shape); | |||||
| auto mem_size = IntToSize(tensor->ElementsNum()); | |||||
| if (x_type_id == kNumberTypeInt) { | |||||
| auto num = GetValue<int>(x_value); | |||||
| SetTensorData(tensor->data_c(), num, mem_size); | |||||
| } else if (x_type_id == kNumberTypeFloat || x_type_id == kNumberTypeFloat32) { | |||||
| auto num = GetValue<float>(x_value); | |||||
| SetTensorData(tensor->data_c(), num, mem_size); | |||||
| } else { | |||||
| MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[2]->ToString(); | |||||
| } | |||||
| abs->set_value(tensor); | |||||
| return abs; | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Fill, prim::kPrimFill, FillInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(Fill, prim::kPrimFill, FillInfer); | ||||
| REGISTER_PRIMITIVE_C(kNameFill, Fill); | REGISTER_PRIMITIVE_C(kNameFill, Fill); | ||||
| @@ -23,15 +23,11 @@ namespace ops { | |||||
| AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto gather_prim = primitive->cast<PrimGatherPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(gather_prim); | |||||
| auto prim_name = gather_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name); | ||||
| // Infer type | // Infer type | ||||
| auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | ||||
| // auto dim_type = input_args[1]->BuildType(); | |||||
| // auto index_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)}; | std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)}; | ||||
| CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); | CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); | ||||
| const std::set<TypeId> valid_index_types = {kNumberTypeInt32, kNumberTypeInt64}; | const std::set<TypeId> valid_index_types = {kNumberTypeInt32, kNumberTypeInt64}; | ||||
| @@ -27,10 +27,6 @@ namespace { | |||||
| abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto conv2d_backprop_filter_prim = primitive->cast<PrimConv2DBackpropFilterPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(conv2d_backprop_filter_prim); | |||||
| // auto prim_name = conv2d_backprop_filter_prim->name(); | |||||
| auto out_put = input_args[2]->BuildValue(); | auto out_put = input_args[2]->BuildValue(); | ||||
| auto infer_shape = GetValue<std::vector<int64_t>>(out_put); | auto infer_shape = GetValue<std::vector<int64_t>>(out_put); | ||||
| return std::make_shared<abstract::Shape>(infer_shape); | return std::make_shared<abstract::Shape>(infer_shape); | ||||
| @@ -23,6 +23,62 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name); | |||||
| for (auto item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto doutput = input_args[0]; | |||||
| auto x_size = input_args[2]; | |||||
| auto x_size_value = x_size->GetValueTrack(); | |||||
| MS_EXCEPTION_IF_NULL(x_size); | |||||
| auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value); | |||||
| // infer dtype | |||||
| auto dtype = doutput->BuildType(); | |||||
| if (!dtype->isa<TensorType>()) { | |||||
| MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString(); | |||||
| } | |||||
| auto input_tensor_type = dtype->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_tensor_type); | |||||
| auto element = input_tensor_type->element(); | |||||
| // infer shape | |||||
| auto dout_shape = doutput->BuildShape(); | |||||
| MS_EXCEPTION_IF_NULL(doutput); | |||||
| auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>(); | |||||
| auto dout_shape_norm = dout_shapeptr->shape(); | |||||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | |||||
| auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | |||||
| auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | |||||
| auto pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList)); | |||||
| auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode))); | |||||
| if (std::all_of(pad_list.begin(), pad_list.end(), [](int64_t elem) -> bool { return elem != 0; })) { | |||||
| primitive->AddAttr(kPadList, MakeValue(pad_list)); | |||||
| } else if (pad_mode == SAME) { | |||||
| auto stride_h = stride[0]; | |||||
| auto stride_w = stride[1]; | |||||
| auto kernel_h = kernel_size[0]; | |||||
| auto kernel_w = kernel_size[1]; | |||||
| auto dilation_h = dilation[2]; | |||||
| auto dilation_w = dilation[3]; | |||||
| auto pad_needed_h = (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2]; | |||||
| pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h; | |||||
| auto pad_top = pad_needed_h / 2; | |||||
| auto pad_bottom = pad_needed_h - pad_top; | |||||
| auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[2]; | |||||
| pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L; | |||||
| auto pad_left = pad_needed_w / 2; | |||||
| auto pad_right = pad_needed_w - pad_left; | |||||
| pad_list = {pad_top, pad_bottom, pad_left, pad_right}; | |||||
| } else if (pad_mode == PAD) { | |||||
| pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad)); | |||||
| } | |||||
| primitive->AddAttr(kPadList, MakeValue(pad_list)); | |||||
| return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v)); | |||||
| } | |||||
| void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, | void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, | ||||
| const PadMode &pad_mode, const std::vector<int64_t> &pad, | const PadMode &pad_mode, const std::vector<int64_t> &pad, | ||||
| const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group, | const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group, | ||||
| @@ -140,6 +196,7 @@ std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const { | |||||
| auto value_ptr = GetAttr(kPadList); | auto value_ptr = GetAttr(kPadList); | ||||
| return GetValue<std::vector<int64_t>>(value_ptr); | return GetValue<std::vector<int64_t>>(value_ptr); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput); | REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -52,9 +52,7 @@ void MaxPoolGrad::set_strides(const std::vector<int64_t> &strides) { | |||||
| AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| auto MaxPoolGrad_prim = primitive->cast<PrimMaxPoolGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(MaxPoolGrad_prim); | |||||
| auto op_name = MaxPoolGrad_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | ||||
| auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); | auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); | ||||
| auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | ||||
| @@ -21,6 +21,81 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| namespace { | |||||
| abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto trans_a = GetValue<bool>(primitive->GetAttr(kTransposeA)); | |||||
| auto trans_b = GetValue<bool>(primitive->GetAttr(kTransposeB)); | |||||
| auto out_n = x_shape[0]; | |||||
| auto out_m = w_shape[1]; | |||||
| auto x_C = x_shape[1]; | |||||
| auto w_C = w_shape[0]; | |||||
| if (trans_a) { | |||||
| out_n = x_shape[1]; | |||||
| x_C = x_shape[0]; | |||||
| } | |||||
| if (trans_b) { | |||||
| out_m = w_shape[0]; | |||||
| w_C = w_shape[1]; | |||||
| } | |||||
| CheckAndConvertUtils::CheckInteger("dim C is not equal", x_C, kEqual, w_C, prim_name); | |||||
| primitive->AddAttr("transpose_x1", MakeValue(trans_a)); | |||||
| primitive->AddAttr("transpose_x2", MakeValue(trans_b)); | |||||
| std::vector<int64_t> out_shape = {out_n, out_m}; | |||||
| return std::make_shared<abstract::Shape>(out_shape); | |||||
| } | |||||
| TypePtr MatMulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, | |||||
| kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("x", input_args[0]->BuildType()); | |||||
| types.emplace("w", input_args[1]->BuildType()); | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||||
| if (infer_type == kNumberTypeInt8) { | |||||
| return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32)); | |||||
| } | |||||
| return TypeIdToType(infer_type); | |||||
| } | |||||
| } // namespace | |||||
| void MatMul::Init(bool transpose_a, bool transpose_b) { | |||||
| set_transpose_a(transpose_a); | |||||
| set_transpose_b(transpose_b); | |||||
| } | |||||
| void MatMul::set_transpose_a(bool transpose_a) { AddAttr(kTransposeA, MakeValue(transpose_a)); } | |||||
| void MatMul::set_transpose_b(bool transpose_b) { AddAttr(kTransposeB, MakeValue(transpose_b)); } | |||||
| bool MatMul::get_transpose_a() const { | |||||
| auto value_ptr = GetAttr(kTransposeA); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| bool MatMul::get_transpose_b() const { | |||||
| auto value_ptr = GetAttr(kTransposeB); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| // Add | |||||
| AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(MatMulInferType(primitive, input_args), | |||||
| MatMulInferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| // Add | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(MatMul, prim::kPrimMatMul, MatMulInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameMatMul, MatMul); | REGISTER_PRIMITIVE_C(kNameMatMul, MatMul); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -94,22 +94,22 @@ void MaxPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto pool_prim = primitive->cast<PrimMaxPoolPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(pool_prim); | |||||
| auto op_name = pool_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | ||||
| } | } | ||||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | ||||
| auto kernel_size = pool_prim->get_kernel_size(); | |||||
| auto pad_mode = pool_prim->get_pad_mode(); | |||||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | |||||
| auto pad_mode_value = (primitive->GetAttr(kPadMode)); | |||||
| PadMode pad_mode = PAD; | |||||
| pad_mode = PadMode(GetValue<int64_t>(pad_mode_value)); | |||||
| auto batch = in_shape[0]; | auto batch = in_shape[0]; | ||||
| auto channel = in_shape[1]; | auto channel = in_shape[1]; | ||||
| auto in_h = in_shape[2]; | auto in_h = in_shape[2]; | ||||
| auto in_w = in_shape[3]; | auto in_w = in_shape[3]; | ||||
| auto strides = pool_prim->get_strides(); | |||||
| auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides)); | |||||
| auto kernel_h = kernel_size[2]; | auto kernel_h = kernel_size[2]; | ||||
| auto kernel_w = kernel_size[3]; | auto kernel_w = kernel_size[3]; | ||||
| auto stride_h = strides[2]; | auto stride_h = strides[2]; | ||||
| @@ -117,14 +117,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| int64_t out_h = -1; | int64_t out_h = -1; | ||||
| int64_t out_w = -1; | int64_t out_w = -1; | ||||
| if (pad_mode == VALID) { | if (pad_mode == VALID) { | ||||
| out_h = ceil((in_h - (kernel_h - 1)) / stride_h); | |||||
| out_w = ceil((in_w - (kernel_w - 1)) / stride_w); | |||||
| out_h = ceil((in_h - (kernel_h - 1) + stride_h - 1) / stride_h); | |||||
| out_w = ceil((in_w - (kernel_w - 1) + stride_w - 1) / stride_w); | |||||
| } else if (pad_mode == SAME) { | } else if (pad_mode == SAME) { | ||||
| out_h = ceil(in_h / stride_h); | out_h = ceil(in_h / stride_h); | ||||
| out_w = ceil(in_w / stride_w); | out_w = ceil(in_w / stride_w); | ||||
| } | } | ||||
| std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| if (format == NHWC) { | |||||
| out_shape = {batch, out_h, out_w, channel}; | out_shape = {batch, out_h, out_w, channel}; | ||||
| } | } | ||||
| if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | ||||
| @@ -137,7 +137,13 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> & | |||||
| if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { | if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { | ||||
| MS_LOG(EXCEPTION) << "nullptr"; | MS_LOG(EXCEPTION) << "nullptr"; | ||||
| } | } | ||||
| return input_args[0]->BuildType(); | |||||
| auto input_type = input_args[0]->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(input_type); | |||||
| auto input_tensor_type = input_type->cast<TensorTypePtr>(); | |||||
| if (input_tensor_type == nullptr) { | |||||
| MS_LOG_EXCEPTION << "The maxpool's input must be a tensor but got " << input_type->ToString(); | |||||
| } | |||||
| return input_tensor_type->element(); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -38,9 +38,9 @@ AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||||
| for (int64_t i = 0; i != (int64_t)inputs_type.size(); i++) { | for (int64_t i = 0; i != (int64_t)inputs_type.size(); i++) { | ||||
| args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]}); | args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]}); | ||||
| } | } | ||||
| std::set<TypePtr> template_type = {TypeIdToType(kNumberTypeBool)}; | |||||
| std::set<TypeId> template_type = {kNumberTypeBool}; | |||||
| for (auto item : common_valid_types) { | for (auto item : common_valid_types) { | ||||
| template_type.insert(TypeIdToType(item)); | |||||
| template_type.insert(item); | |||||
| } | } | ||||
| CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name); | CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name); | ||||
| std::vector<int64_t> in_shape0 = inputs_shape[0]->cast<abstract::ShapePtr>()->shape(); | std::vector<int64_t> in_shape0 = inputs_shape[0]->cast<abstract::ShapePtr>()->shape(); | ||||
| @@ -30,13 +30,17 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| // infer shape | // infer shape | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto shape_prim = primitive->cast<PrimShapePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(shape_prim); | |||||
| auto op_name = shape_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | ||||
| // infer type | // infer type | ||||
| auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| return std::make_shared<abstract::AbstractTensor>(x_type, in_shape); | |||||
| AbstractBasePtrList abs_list; | |||||
| std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list), | |||||
| [](int64_t item) -> std::shared_ptr<abstract::AbstractScalar> { | |||||
| return std::make_shared<abstract::AbstractScalar>(item); | |||||
| }); | |||||
| auto abs = std::make_shared<abstract::AbstractTuple>(abs_list); | |||||
| abs->set_value(MakeValue(in_shape)); | |||||
| return abs; | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer); | ||||
| REGISTER_PRIMITIVE_C(kNameShape, Shape); | REGISTER_PRIMITIVE_C(kNameShape, Shape); | ||||
| @@ -60,7 +60,6 @@ AbstractBasePtr ZerosLikeInfer(const abstract::AnalysisEnginePtr &, const Primit | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | ||||
| InferShape(primitive, input_args)->shape()); | InferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ZerosLike, prim::kPrimZerosLike, ZerosLikeInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameZerosLike, ZerosLike); | REGISTER_PRIMITIVE_C(kNameZerosLike, ZerosLike); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -486,7 +486,7 @@ void CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const Typ | |||||
| } | } | ||||
| void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args, | void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args, | ||||
| const std::set<TypePtr> &valid_values, | |||||
| const std::set<TypeId> &valid_values, | |||||
| const std::string &prim_name, const bool allow_mix) { | const std::string &prim_name, const bool allow_mix) { | ||||
| std::vector<std::map<std::string, TypePtr>> check_results; | std::vector<std::map<std::string, TypePtr>> check_results; | ||||
| for (auto &iter : args) { | for (auto &iter : args) { | ||||
| @@ -502,7 +502,7 @@ void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::stri | |||||
| } | } | ||||
| std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckArgumentType(const std::map<std::string, TypePtr> &arg, | std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckArgumentType(const std::map<std::string, TypePtr> &arg, | ||||
| const std::set<TypePtr> &valid_values, | |||||
| const std::set<TypeId> &valid_values, | |||||
| const std::string &prim_name) { | const std::string &prim_name) { | ||||
| std::string arg_key = arg.begin()->first; | std::string arg_key = arg.begin()->first; | ||||
| TypePtr arg_val = arg.begin()->second; | TypePtr arg_val = arg.begin()->second; | ||||
| @@ -512,15 +512,16 @@ std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckArgumentType(const st | |||||
| arg_val = arg_val_->element(); | arg_val = arg_val_->element(); | ||||
| } | } | ||||
| auto it = valid_values.find(arg_val); | |||||
| auto it = valid_values.find(arg_val->type_id()); | |||||
| if (it == valid_values.end()) { | if (it == valid_values.end()) { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| buffer << "For '" << prim_name << "' , the `" << arg_key << "` should be in { "; | buffer << "For '" << prim_name << "' , the `" << arg_key << "` should be in { "; | ||||
| for (auto valid_value : valid_values) { | for (auto valid_value : valid_values) { | ||||
| buffer << valid_value->ToString() << " },"; | |||||
| buffer << "but `" << arg_key << "`" | |||||
| << "is" << arg_val->ToString() << "."; | |||||
| buffer << TypeIdToType(valid_value)->ToString() << ","; | |||||
| } | } | ||||
| buffer << " },"; | |||||
| buffer << "but `" << arg_key << "`" | |||||
| << "is" << arg_val->ToString() << "."; | |||||
| MS_EXCEPTION(TypeError) << buffer.str(); | MS_EXCEPTION(TypeError) << buffer.str(); | ||||
| } | } | ||||
| return arg; | return arg; | ||||
| @@ -546,7 +547,7 @@ std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckTypeSame(const std::m | |||||
| except_flag = true; | except_flag = true; | ||||
| } | } | ||||
| if (except_flag || arg1_type != arg2_type) { | |||||
| if (except_flag || arg1_type->type_id() != arg2_type->type_id()) { | |||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| buffer << "For '" << prim_name << "'" | buffer << "For '" << prim_name << "'" | ||||
| << "type of " | << "type of " | ||||
| @@ -277,7 +277,7 @@ class CheckAndConvertUtils { | |||||
| static void CheckSubClass(const std::string &type_name, const TypePtr type, const std::set<TypePtr> &template_types, | static void CheckSubClass(const std::string &type_name, const TypePtr type, const std::set<TypePtr> &template_types, | ||||
| const std::string &prim_name); | const std::string &prim_name); | ||||
| static void CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args, | static void CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args, | ||||
| const std::set<TypePtr> &valid_values, const std::string &prim_name, | |||||
| const std::set<TypeId> &valid_values, const std::string &prim_name, | |||||
| bool allow_mix = false); | bool allow_mix = false); | ||||
| static TypeId CheckTypeSame(const std::string &arg_name, const TypePtr arg_type, const std::set<TypeId> &valid_type, | static TypeId CheckTypeSame(const std::string &arg_name, const TypePtr arg_type, const std::set<TypeId> &valid_type, | ||||
| const std::string &prim_name); | const std::string &prim_name); | ||||
| @@ -291,7 +291,7 @@ class CheckAndConvertUtils { | |||||
| private: | private: | ||||
| static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2); | static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2); | ||||
| static std::map<std::string, TypePtr> _CheckArgumentType(const std::map<std::string, TypePtr> &arg, | static std::map<std::string, TypePtr> _CheckArgumentType(const std::map<std::string, TypePtr> &arg, | ||||
| const std::set<TypePtr> &valid_values, | |||||
| const std::set<TypeId> &valid_values, | |||||
| const std::string &prim_name); | const std::string &prim_name); | ||||
| static std::map<std::string, TypePtr> _CheckTypeSame(const std::map<std::string, TypePtr> &arg1, | static std::map<std::string, TypePtr> _CheckTypeSame(const std::map<std::string, TypePtr> &arg1, | ||||
| const std::map<std::string, TypePtr> &arg2, | const std::map<std::string, TypePtr> &arg2, | ||||
| @@ -17,21 +17,8 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace { | |||||
| template <typename T> | |||||
| void SetTensorData(void *data, float num, size_t data_length) { | |||||
| MS_EXCEPTION_IF_NULL(data); | |||||
| auto tensor_data = reinterpret_cast<T *>(data); | |||||
| MS_EXCEPTION_IF_NULL(tensor_data); | |||||
| for (size_t index = 0; index < data_length; ++index) { | |||||
| *tensor_data = num; | |||||
| ++tensor_data; | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) { | tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) { | ||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | ||||
| size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); | size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); | ||||
| auto tensor_data = tensor->data_c(); | auto tensor_data = tensor->data_c(); | ||||
| char *data = reinterpret_cast<char *>(tensor_data); | char *data = reinterpret_cast<char *>(tensor_data); | ||||
| @@ -43,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std | |||||
| tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) { | tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) { | ||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | ||||
| auto mem_size = IntToSize(tensor->ElementsNum()); | |||||
| size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); | |||||
| if (tensor->data_type() == kNumberTypeFloat32) { | if (tensor->data_type() == kNumberTypeFloat32) { | ||||
| SetTensorData<float>(tensor->data_c(), 1.0, mem_size); | |||||
| SetTensorData(tensor->data_c(), 1.0, mem_size); | |||||
| } else if (tensor->data_type() == kNumberTypeInt) { | } else if (tensor->data_type() == kNumberTypeInt) { | ||||
| SetTensorData<int>(tensor->data_c(), 1, mem_size); | |||||
| SetTensorData(tensor->data_c(), 1, mem_size); | |||||
| } | } | ||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| @@ -18,6 +18,16 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| template <typename T> | |||||
| void SetTensorData(void *data, T num, size_t data_length) { | |||||
| MS_EXCEPTION_IF_NULL(data); | |||||
| auto tensor_data = reinterpret_cast<T *>(data); | |||||
| MS_EXCEPTION_IF_NULL(tensor_data); | |||||
| for (size_t index = 0; index < data_length; ++index) { | |||||
| *tensor_data = num; | |||||
| ++tensor_data; | |||||
| } | |||||
| } | |||||
| class TensorConstructUtils { | class TensorConstructUtils { | ||||
| public: | public: | ||||
| static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape); | static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape); | ||||