From: @lianliguang
Tag: v1.2.0-rc1
@@ -308,6 +308,10 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri
                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                             const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list);
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple or list or dict.
@@ -202,6 +202,23 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti
   return std::make_shared<AbstractTuple>(elements);
 }
+AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                             const AbstractBasePtrList &args_spec_list) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto is_grad = GetValue<bool>(primitive->GetAttr("is_grad"));
+  CheckArgsSize(primitive->name(), args_spec_list, 2);
+  std::shared_ptr<BaseShape> shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{});
+  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+  if (is_grad) {
+    shape = args_spec_list[0]->BuildShape();
+  }
+  auto type = args_spec_list[0]->BuildType();
+  MS_EXCEPTION_IF_NULL(type);
+  auto type_tensor = type->cast<TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(type_tensor);
+  return std::make_shared<abstract::AbstractTensor>(type_tensor->element(), shape);
+}
 AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const AbstractBasePtrList &args_spec_list) {
   // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
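Review note: the new infer rule above yields a scalar (empty shape) for the forward loss and reuses the logits shape when `is_grad` is set. A minimal stand-alone sketch of just that shape rule; the helper name and the use of `std::vector<int64_t>` in place of MindSpore's shape abstraction are illustrative assumptions, not part of the patch:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of the shape rule: the forward loss is a scalar
// (rank-0), while the gradient keeps the logits shape.
std::vector<int64_t> SparseSoftmaxCeShape(const std::vector<int64_t> &logits_shape, bool is_grad) {
  return is_grad ? logits_shape : std::vector<int64_t>{};
}

int main() {
  assert(SparseSoftmaxCeShape({32, 10}, false).empty());                           // loss: scalar
  assert((SparseSoftmaxCeShape({32, 10}, true) == std::vector<int64_t>{32, 10}));  // grad: [32, 10]
  return 0;
}
```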
@@ -559,5 +559,21 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con
   ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape);
   return std::make_shared<AbstractTensor>(input->element(), shape);
 }
+AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+  auto type = args_spec_list[0]->BuildType();
+  MS_EXCEPTION_IF_NULL(type);
+  auto tensor_type = type->cast<TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(tensor_type);
+  auto value = tensor_type->element();
+  auto abstract = std::make_shared<abstract::AbstractType>(value);
+  abstract->set_value(value);
+  return abstract;
+}
 } // namespace abstract
 } // namespace mindspore
@@ -193,6 +193,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
     {prim::kPrimCast, {InferImplCast, true}},
     {prim::kPrimExpandDims, {InferImplExpandDims, true}},
+    {prim::kPrimSparseSoftmaxCrossEntropyWithLogits, {InferImplSparseSoftmaxCrossEntropyWithLogits, true}},
+    {prim::kPrimDType, {InferImplDType, true}},
   };
   return prim_eval_implement_map;
 }
@@ -521,6 +521,7 @@ inline const PrimitivePtr kPrimTopKFusion = std::make_shared<Primitive>("TopKFus
 inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion");
 inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion");
 inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion");
+inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType");
 class DoSignaturePrimitive : public Primitive {
  public:
@@ -56,32 +56,29 @@ float ApplyMomentum::get_gradient_scale() const {
 AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto momentum_prim = primitive->cast<PrimApplyMomentumPtr>();
-  MS_EXCEPTION_IF_NULL(momentum_prim);
-  auto prim_name = momentum_prim->name();
+  auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name);
   // Infer shape
   auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name);
   // Infer type
-  auto v_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  auto a_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
-  auto l_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
-  auto g_type = input_args[3]->BuildType()->cast<TensorTypePtr>()->element();
-  auto m_type = input_args[4]->BuildType()->cast<TensorTypePtr>()->element();
+  auto v_tensor_type = input_args[0]->BuildType();
+  auto a_tensor_type = input_args[1]->BuildType();
+  auto l_type = input_args[2]->BuildType();
+  auto g_type = input_args[3]->BuildType();
+  auto m_type = input_args[4]->BuildType();
   const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
-  CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, valid_types, prim_name);
-  CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_type, valid_types, prim_name);
-  const std::set<TypePtr> valid_types_ptr = {TypeIdToType(kNumberTypeFloat16), TypeIdToType(kNumberTypeFloat32),
-                                             TypeIdToType(kNumberTypeFloat64)};
+  CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
+  CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
   std::map<std::string, TypePtr> args;
   args.insert({"l_type", l_type});
   args.insert({"g_type", g_type});
   args.insert({"m_type", m_type});
-  CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types_ptr, prim_name);
-  return std::make_shared<abstract::AbstractTensor>(g_type, v_shape);
+  CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types, prim_name);
+  auto g_type_tensor = g_type->cast<TensorTypePtr>();
+  auto element = g_type_tensor->element();
+  return std::make_shared<abstract::AbstractTensor>(element, v_shape);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer);
 REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum);
@@ -61,7 +61,7 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
     condition = input_args[0]->BuildType();
   }
   std::vector<int64_t> output_shape = {1};
-  std::set<TypePtr> local_bool = {TypeIdToType(kNumberTypeBool)};
+  std::set<TypeId> local_bool = {kNumberTypeBool};
   std::map<std::string, TypePtr> args = {{"condition", condition}};
   CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
   auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements();
@@ -23,6 +23,42 @@
 namespace mindspore {
 namespace ops {
 // Add
+namespace {
+abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  // check
+  CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name);
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
+  auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name);
+  CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name);
+  CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name);
+  auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
+  auto x_channel = x_shape[1];
+  if (format != NCHW) {
+    x_channel = x_shape[x_shape.size() - 1];
+  }
| CheckAndConvertUtils::Check("b_shape[0]", b_shape[0], kEqual, "x_shape[1]", x_channel, prim_name); | |||
+  std::vector<int64_t> out_shape = x_shape;
+  return std::make_shared<abstract::Shape>(out_shape);
+}
+TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto prim_name = prim->name();
+  CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  std::map<std::string, TypePtr> types;
+  types.emplace("input_x", input_args[0]->BuildType());
+  types.emplace("bias", input_args[1]->BuildType());
+  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
+  return TypeIdToType(infer_type);
+}
+} // namespace
 void BiasAdd::set_format(const Format &format) {
   int64_t f = format;
   this->AddAttr(kFormat, MakeValue(f));
@@ -32,7 +68,13 @@ Format BiasAdd::get_format() const {
   return Format(GetValue<int64_t>(value_ptr));
 }
 void BiasAdd::Init(const Format &format) { this->set_format(format); }
+AbstractBasePtr BiasAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const std::vector<AbstractBasePtr> &input_args) {
+  return std::make_shared<abstract::AbstractTensor>(BiasAddInferType(primitive, input_args),
+                                                    BiasAddInferShape(primitive, input_args));
+}
+// Add
 REGISTER_PRIMITIVE_EVAL_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddInfer);
 REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd);
 } // namespace ops
 } // namespace mindspore
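Review note: `BiasAddInferShape` validates the bias length against the format-dependent channel dimension, index 1 for NCHW and the last axis otherwise. A tiny self-contained sketch of that selection (the helper is hypothetical, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: pick the channel dim BiasAdd checks against the bias
// length. NCHW keeps channels at index 1; other formats use the last axis.
int64_t BiasChannel(const std::vector<int64_t> &x_shape, bool is_nchw) {
  return is_nchw ? x_shape[1] : x_shape[x_shape.size() - 1];
}

int main() {
  assert(BiasChannel({8, 3, 32, 32}, /*is_nchw=*/true) == 3);   // NCHW
  assert(BiasChannel({8, 32, 32, 3}, /*is_nchw=*/false) == 3);  // NHWC
  return 0;
}
```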
@@ -30,32 +30,31 @@ namespace ops {
 namespace {
 abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto conv_prim = primitive->cast<PrimConv2dPtr>();
-  MS_EXCEPTION_IF_NULL(conv_prim);
-  auto prim_name = conv_prim->name();
+  auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
   auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
-  if (conv_prim->get_format() == NHWC) {
+  auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
+  if (format == NHWC) {
     x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
     w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
   }
   CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
   CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
-  CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]",
-                              w_shape[1], conv_prim->name());
-  auto out_channel = conv_prim->get_out_channel();
-  CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
+  CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
+                              "w_shape[1]", w_shape[1], prim_name);
+  auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
+  CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
   std::vector<int64_t> temp_w;
   std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
-  CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w,
-                              conv_prim->name());
+  CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
+                              "w_shape[2:4]", temp_w, prim_name);
   auto kernel_size_h = w_shape[2];
   auto kernel_size_w = w_shape[3];
-  auto stride = conv_prim->get_stride();
-  auto dilation = conv_prim->get_dilation();
+  auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
+  auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation));
   auto stride_h = stride[2];
   auto stride_w = stride[3];
   auto dilation_h = dilation[2];
@@ -63,7 +62,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
   int64_t h_out = -1;
   int64_t w_out = -1;
   std::vector<int64_t> pad_list(4, 0);
-  auto pad_mode = conv_prim->get_pad_mode();
+  auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
   if (pad_mode == VALID) {
     h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
     w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
@@ -81,20 +80,23 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
     pad_list.emplace_back(pad_left);
     pad_list.emplace_back(pad_needed_h - pad_left);
   } else if (pad_mode == PAD) {
-    std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list));
-    auto pad_top = conv_prim->get_pad()[0];
-    auto pad_bottom = conv_prim->get_pad()[1];
-    auto pad_right = conv_prim->get_pad()[2];
-    auto pad_left = conv_prim->get_pad()[3];
+    auto pad = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
+    std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list));
+    auto pad_top = pad[0];
+    auto pad_bottom = pad[1];
+    auto pad_right = pad[2];
+    auto pad_left = pad[3];
     h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
     w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
     h_out = floor(h_out);
     w_out = floor(w_out);
   }
-  conv_prim->set_pad(pad_list);
+  CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name);
+  primitive->AddAttr(kPadList,
+                     MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true)));
   std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
-  if (conv_prim->get_format() == NHWC) {
+  if (format == NHWC) {
     out_shape = {x_shape[0], h_out, w_out, out_channel};
   }
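Review note: for reference, the output-size arithmetic used by the three pad modes, as a stand-alone sketch (helper names are assumptions). One caveat worth flagging on the VALID branch above: since the operands are `int64_t`, the division truncates before `ceil()` runs, so that `ceil` is effectively a floor; the sketch uses an integer ceiling division instead:

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division for positive integers.
int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Kernel extent once dilation is applied.
int64_t EffectiveKernel(int64_t k, int64_t d) { return d * (k - 1) + 1; }

// VALID: only full windows, no padding.
int64_t OutValid(int64_t in, int64_t k, int64_t s, int64_t d) { return CeilDiv(in - EffectiveKernel(k, d) + 1, s); }

// SAME: every input position is covered; padding is derived afterwards.
int64_t OutSame(int64_t in, int64_t s) { return CeilDiv(in, s); }

// PAD: explicit padding on both sides.
int64_t OutPad(int64_t in, int64_t pad_a, int64_t pad_b, int64_t k, int64_t s, int64_t d) {
  return 1 + (in + pad_a + pad_b - EffectiveKernel(k, d)) / s;
}

int main() {
  assert(OutValid(32, 3, 1, 1) == 30);
  assert(OutSame(32, 2) == 16);
  assert(OutPad(32, 1, 1, 3, 1, 1) == 32);
  return 0;
}
```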
@@ -15,8 +15,10 @@
  */
 #include "ops/fill.h"
 #include <memory>
+#include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
+#include "utils/tensor_construct_utils.h"
 namespace mindspore {
 namespace ops {
@@ -38,7 +40,23 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
   valid_types.insert(kNumberTypeBool);
   CheckAndConvertUtils::CheckTypeSame("output datatype", dtype, valid_types, prim_name);
   auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue());
-  return std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape));
+  auto x_type = input_args[2]->BuildType();
+  auto x_type_id = x_type->type_id();
+  auto x_value = input_args[2]->BuildValue();
+  auto abs = std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape));
+  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(x_type_id, out_shape);
+  auto mem_size = IntToSize(tensor->ElementsNum());
+  if (x_type_id == kNumberTypeInt) {
+    auto num = GetValue<int>(x_value);
+    SetTensorData(tensor->data_c(), num, mem_size);
+  } else if (x_type_id == kNumberTypeFloat || x_type_id == kNumberTypeFloat32) {
+    auto num = GetValue<float>(x_value);
+    SetTensorData(tensor->data_c(), num, mem_size);
+  } else {
| MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[2]->ToString(); | |||
+  }
+  abs->set_value(tensor);
+  return abs;
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Fill, prim::kPrimFill, FillInfer);
 REGISTER_PRIMITIVE_C(kNameFill, Fill);
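Review note: `FillInfer` now constant-folds the fill value into a tensor attached to the abstract. A simplified model of the type-dispatched fill, using `std::variant` as a stand-in for the `BuildValue`/`GetValue` dispatch; all names here are illustrative:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <variant>
#include <vector>

// Illustrative stand-in for the constant-fold dispatch in FillInfer: the
// scalar's runtime type decides which typed fill runs.
using Scalar = std::variant<int, float>;

std::vector<float> FoldFill(const Scalar &value, std::size_t n) {
  std::vector<float> out(n);
  float v = std::holds_alternative<int>(value) ? static_cast<float>(std::get<int>(value)) : std::get<float>(value);
  std::fill(out.begin(), out.end(), v);
  return out;
}

int main() {
  auto t = FoldFill(Scalar{3}, 4);
  assert(t.size() == 4 && t[0] == 3.0f && t[3] == 3.0f);
  return 0;
}
```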
@@ -23,15 +23,11 @@ namespace ops {
 AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto gather_prim = primitive->cast<PrimGatherPtr>();
-  MS_EXCEPTION_IF_NULL(gather_prim);
-  auto prim_name = gather_prim->name();
+  auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name);
   // Infer type
-  auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  // auto dim_type = input_args[1]->BuildType();
-  // auto index_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
   std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)};
   CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
   const std::set<TypeId> valid_index_types = {kNumberTypeInt32, kNumberTypeInt64};
@@ -27,10 +27,6 @@ namespace {
 abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
                                                   const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto conv2d_backprop_filter_prim = primitive->cast<PrimConv2DBackpropFilterPtr>();
-  MS_EXCEPTION_IF_NULL(conv2d_backprop_filter_prim);
-  // auto prim_name = conv2d_backprop_filter_prim->name();
   auto out_put = input_args[2]->BuildValue();
   auto infer_shape = GetValue<std::vector<int64_t>>(out_put);
   return std::make_shared<abstract::Shape>(infer_shape);
@@ -23,6 +23,62 @@
 namespace mindspore {
 namespace ops {
+AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
+  for (auto item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto doutput = input_args[0];
+  auto x_size = input_args[2];
+  MS_EXCEPTION_IF_NULL(x_size);
+  auto x_size_value = x_size->GetValueTrack();
+  auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
+  // infer dtype
+  auto dtype = doutput->BuildType();
+  if (!dtype->isa<TensorType>()) {
| MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString(); | |||
+  }
+  auto input_tensor_type = dtype->cast<TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(input_tensor_type);
+  auto element = input_tensor_type->element();
+  // infer shape
+  auto dout_shape = doutput->BuildShape();
+  MS_EXCEPTION_IF_NULL(dout_shape);
+  auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(dout_shapeptr);
+  auto dout_shape_norm = dout_shapeptr->shape();
+  auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
+  auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
+  auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation));
+  auto pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
+  auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
+  if (std::all_of(pad_list.begin(), pad_list.end(), [](int64_t elem) -> bool { return elem != 0; })) {
+    primitive->AddAttr(kPadList, MakeValue(pad_list));
+  } else if (pad_mode == SAME) {
+    auto stride_h = stride[0];
+    auto stride_w = stride[1];
+    auto kernel_h = kernel_size[0];
+    auto kernel_w = kernel_size[1];
+    auto dilation_h = dilation[2];
+    auto dilation_w = dilation[3];
+    auto pad_needed_h = (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2];
+    pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
+    auto pad_top = pad_needed_h / 2;
+    auto pad_bottom = pad_needed_h - pad_top;
+    auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3];
+    pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
+    auto pad_left = pad_needed_w / 2;
+    auto pad_right = pad_needed_w - pad_left;
+    pad_list = {pad_top, pad_bottom, pad_left, pad_right};
+  } else if (pad_mode == PAD) {
+    pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
+  }
+  primitive->AddAttr(kPadList, MakeValue(pad_list));
+  return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v));
+}
 void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
                                const PadMode &pad_mode, const std::vector<int64_t> &pad,
                                const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
@@ -140,6 +196,7 @@ std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
   auto value_ptr = GetAttr(kPadList);
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
+REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer);
 REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput);
 } // namespace ops
 } // namespace mindspore
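Review note: the SAME branch recomputes the pad amounts from the forward output extent. Sketch of the height-direction computation (the width direction is symmetric; all names here are assumptions, and note the width line above originally subtracted `x_size_v[2]`, which looks like a copy of the height case):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Given the forward output extent, stride, kernel extent, dilation, and the
// original input extent, return {pad_before, pad_after} for SAME padding.
std::pair<int64_t, int64_t> SamePad(int64_t out, int64_t stride, int64_t kernel, int64_t dilation, int64_t in) {
  int64_t needed = (out - 1) * stride + dilation * (kernel - 1) + 1 - in;
  if (needed < 0) {
    needed = 0;
  }
  int64_t before = needed / 2;
  return {before, needed - before};
}

int main() {
  // 5-wide input, 3-wide kernel, stride 2: SAME output extent is 3, total pad 2.
  auto pads = SamePad(3, 2, 3, 1, 5);
  assert(pads.first == 1 && pads.second == 1);
  return 0;
}
```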
@@ -52,9 +52,7 @@ void MaxPoolGrad::set_strides(const std::vector<int64_t> &strides) {
 AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
-  auto MaxPoolGrad_prim = primitive->cast<PrimMaxPoolGradPtr>();
-  MS_EXCEPTION_IF_NULL(MaxPoolGrad_prim);
-  auto op_name = MaxPoolGrad_prim->name();
+  auto op_name = primitive->name();
   MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
   auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name);
   auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
@@ -21,6 +21,81 @@
 namespace mindspore {
 namespace ops {
+namespace {
+abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name);
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
+  auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
+  auto trans_a = GetValue<bool>(primitive->GetAttr(kTransposeA));
+  auto trans_b = GetValue<bool>(primitive->GetAttr(kTransposeB));
+  auto out_n = x_shape[0];
+  auto out_m = w_shape[1];
+  auto x_C = x_shape[1];
+  auto w_C = w_shape[0];
+  if (trans_a) {
+    out_n = x_shape[1];
+    x_C = x_shape[0];
+  }
+  if (trans_b) {
+    out_m = w_shape[0];
+    w_C = w_shape[1];
+  }
| CheckAndConvertUtils::CheckInteger("dim C is not equal", x_C, kEqual, w_C, prim_name); | |||
| primitive->AddAttr("transpose_x1", MakeValue(trans_a)); | |||
| primitive->AddAttr("transpose_x2", MakeValue(trans_b)); | |||
| std::vector<int64_t> out_shape = {out_n, out_m}; | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr MatMulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, | |||
| kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x", input_args[0]->BuildType()); | |||
| types.emplace("w", input_args[1]->BuildType()); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| if (infer_type == kNumberTypeInt8) { | |||
| return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32)); | |||
| } | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| } // namespace | |||
| void MatMul::Init(bool transpose_a, bool transpose_b) { | |||
| set_transpose_a(transpose_a); | |||
| set_transpose_b(transpose_b); | |||
| } | |||
| void MatMul::set_transpose_a(bool transpose_a) { AddAttr(kTransposeA, MakeValue(transpose_a)); } | |||
| void MatMul::set_transpose_b(bool transpose_b) { AddAttr(kTransposeB, MakeValue(transpose_b)); } | |||
| bool MatMul::get_transpose_a() const { | |||
| auto value_ptr = GetAttr(kTransposeA); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| bool MatMul::get_transpose_b() const { | |||
| auto value_ptr = GetAttr(kTransposeB); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| // Add | |||
| AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(MatMulInferType(primitive, input_args), | |||
| MatMulInferShape(primitive, input_args)->shape()); | |||
| } | |||
| // Add | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(MatMul, prim::kPrimMatMul, MatMulInfer); | |||
| REGISTER_PRIMITIVE_C(kNameMatMul, MatMul); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
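Review note: `MatMulInferShape` swaps which dimension of an operand is contracted when that operand's transpose flag is set, and `MatMulInferType` promotes int8 x int8 to int32. A stand-alone sketch of the shape rule (the helper is illustrative, not part of the patch):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Output shape of a 2-D matmul: a transpose flag swaps which dim of that
// operand is contracted.
std::array<int64_t, 2> MatMulOutShape(std::array<int64_t, 2> x, std::array<int64_t, 2> w, bool trans_a, bool trans_b) {
  int64_t n = trans_a ? x[1] : x[0];
  int64_t x_c = trans_a ? x[0] : x[1];
  int64_t w_c = trans_b ? w[1] : w[0];
  int64_t m = trans_b ? w[0] : w[1];
  assert(x_c == w_c);  // the contracted dims must agree
  return {n, m};
}

int main() {
  assert((MatMulOutShape({2, 3}, {3, 4}, false, false) == std::array<int64_t, 2>{2, 4}));
  assert((MatMulOutShape({3, 2}, {4, 3}, true, true) == std::array<int64_t, 2>{2, 4}));
  return 0;
}
```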
@@ -94,22 +94,22 @@ void MaxPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto pool_prim = primitive->cast<PrimMaxPoolPtr>();
-  MS_EXCEPTION_IF_NULL(pool_prim);
-  auto op_name = pool_prim->name();
+  auto op_name = primitive->name();
   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
-  if (pool_prim->get_format() == NHWC) {
+  auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
+  if (format == NHWC) {
     in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
   }
   CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
-  auto kernel_size = pool_prim->get_kernel_size();
-  auto pad_mode = pool_prim->get_pad_mode();
+  auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
+  auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
   auto batch = in_shape[0];
   auto channel = in_shape[1];
   auto in_h = in_shape[2];
   auto in_w = in_shape[3];
-  auto strides = pool_prim->get_strides();
+  auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
   auto kernel_h = kernel_size[2];
   auto kernel_w = kernel_size[3];
   auto stride_h = strides[2];
@@ -117,14 +117,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   int64_t out_h = -1;
   int64_t out_w = -1;
   if (pad_mode == VALID) {
-    out_h = ceil((in_h - (kernel_h - 1)) / stride_h);
-    out_w = ceil((in_w - (kernel_w - 1)) / stride_w);
+    out_h = ceil((in_h - (kernel_h - 1) + stride_h - 1) / stride_h);
+    out_w = ceil((in_w - (kernel_w - 1) + stride_w - 1) / stride_w);
   } else if (pad_mode == SAME) {
     out_h = ceil(in_h / stride_h);
     out_w = ceil(in_w / stride_w);
   }
   std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
-  if (pool_prim->get_format() == NHWC) {
+  if (format == NHWC) {
     out_shape = {batch, out_h, out_w, channel};
   }
   if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
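Review note on the VALID change above: with integer operands, `a / b` truncates before `ceil()` is applied, so the old expression never rounded up; the `(a + b - 1) / b` idiom performs the ceiling in integer math. (The SAME branch still has the same truncate-then-ceil pattern and may deserve the same treatment.) Quick demonstration:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  int64_t a = 7, b = 2;
  // Integer division truncates first, so ceil() only ever sees 3.0 here.
  assert(static_cast<int64_t>(std::ceil(a / b)) == 3);
  // The add-before-divide idiom computes the true ceiling for positive ints.
  assert((a + b - 1) / b == 4);
  return 0;
}
```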
@@ -137,7 +137,13 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
-  return input_args[0]->BuildType();
+  auto input_type = input_args[0]->BuildType();
+  MS_EXCEPTION_IF_NULL(input_type);
+  auto input_tensor_type = input_type->cast<TensorTypePtr>();
+  if (input_tensor_type == nullptr) {
+    MS_LOG_EXCEPTION << "MaxPool's input must be a tensor, but got " << input_type->ToString();
+  }
+  return input_tensor_type->element();
 }
 } // namespace
@@ -38,9 +38,9 @@ AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
   for (int64_t i = 0; i != (int64_t)inputs_type.size(); i++) {
     args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]});
   }
-  std::set<TypePtr> template_type = {TypeIdToType(kNumberTypeBool)};
+  std::set<TypeId> template_type = {kNumberTypeBool};
   for (auto item : common_valid_types) {
-    template_type.insert(TypeIdToType(item));
+    template_type.insert(item);
   }
   CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name);
   std::vector<int64_t> in_shape0 = inputs_shape[0]->cast<abstract::ShapePtr>()->shape();
@@ -30,13 +30,17 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
                            const std::vector<AbstractBasePtr> &input_args) {
   // infer shape
   MS_EXCEPTION_IF_NULL(primitive);
-  auto shape_prim = primitive->cast<PrimShapePtr>();
-  MS_EXCEPTION_IF_NULL(shape_prim);
-  auto op_name = shape_prim->name();
+  auto op_name = primitive->name();
   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
   // infer type
-  auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  return std::make_shared<abstract::AbstractTensor>(x_type, in_shape);
+  AbstractBasePtrList abs_list;
+  std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list),
+                 [](int64_t item) -> std::shared_ptr<abstract::AbstractScalar> {
+                   return std::make_shared<abstract::AbstractScalar>(item);
+                 });
+  auto abs = std::make_shared<abstract::AbstractTuple>(abs_list);
+  abs->set_value(MakeValue(in_shape));
+  return abs;
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer);
 REGISTER_PRIMITIVE_C(kNameShape, Shape);
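Review note: `ShapeInfer` now returns a tuple of known scalars with the concrete shape attached as its value, so downstream passes can fold `Shape` at compile time instead of treating it as a tensor. The `std::transform` into `std::back_inserter` pattern it uses, in a stand-alone form (strings stand in for abstract scalars purely for illustration):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <string>
#include <vector>

// Each dimension of the input shape becomes one element of the result.
int main() {
  std::vector<int64_t> in_shape = {4, 3, 2};
  std::vector<std::string> elems;
  std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(elems),
                 [](int64_t dim) { return std::to_string(dim); });
  assert(elems.size() == 3 && elems.front() == "4" && elems.back() == "2");
  return 0;
}
```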
@@ -60,7 +60,6 @@ AbstractBasePtr ZerosLikeInfer(const abstract::AnalysisEnginePtr &, const Primit
   return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
                                                     InferShape(primitive, input_args)->shape());
 }
-REGISTER_PRIMITIVE_EVAL_IMPL(ZerosLike, prim::kPrimZerosLike, ZerosLikeInfer);
 REGISTER_PRIMITIVE_C(kNameZerosLike, ZerosLike);
 } // namespace ops
 } // namespace mindspore
@@ -486,7 +486,7 @@ void CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const Typ
 }
 void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
-                                                        const std::set<TypePtr> &valid_values,
+                                                        const std::set<TypeId> &valid_values,
                                                         const std::string &prim_name, const bool allow_mix) {
   std::vector<std::map<std::string, TypePtr>> check_results;
   for (auto &iter : args) {
@@ -502,7 +502,7 @@ void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::stri
 }
 std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckArgumentType(const std::map<std::string, TypePtr> &arg,
-                                                                        const std::set<TypePtr> &valid_values,
+                                                                        const std::set<TypeId> &valid_values,
                                                                         const std::string &prim_name) {
   std::string arg_key = arg.begin()->first;
   TypePtr arg_val = arg.begin()->second;
@@ -512,15 +512,16 @@ std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckArgumentType(const st
     arg_val = arg_val_->element();
   }
-  auto it = valid_values.find(arg_val);
+  auto it = valid_values.find(arg_val->type_id());
   if (it == valid_values.end()) {
     std::ostringstream buffer;
     buffer << "For '" << prim_name << "' , the `" << arg_key << "` should be in { ";
     for (auto valid_value : valid_values) {
-      buffer << valid_value->ToString() << " },";
-      buffer << "but `" << arg_key << "`"
-             << "is" << arg_val->ToString() << ".";
+      buffer << TypeIdToType(valid_value)->ToString() << ",";
     }
+    buffer << " }, ";
+    buffer << "but `" << arg_key << "` "
+           << "is " << arg_val->ToString() << ".";
     MS_EXCEPTION(TypeError) << buffer.str();
   }
   return arg;
@@ -546,7 +547,7 @@ std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckTypeSame(const std::m
     except_flag = true;
   }
-  if (except_flag || arg1_type != arg2_type) {
+  if (except_flag || arg1_type->type_id() != arg2_type->type_id()) {
     std::ostringstream buffer;
     buffer << "For '" << prim_name << "'"
           << "type of "
@@ -277,7 +277,7 @@ class CheckAndConvertUtils {
   static void CheckSubClass(const std::string &type_name, const TypePtr type, const std::set<TypePtr> &template_types,
                             const std::string &prim_name);
   static void CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
-                                           const std::set<TypePtr> &valid_values, const std::string &prim_name,
+                                           const std::set<TypeId> &valid_values, const std::string &prim_name,
                                            bool allow_mix = false);
   static TypeId CheckTypeSame(const std::string &arg_name, const TypePtr arg_type, const std::set<TypeId> &valid_type,
                               const std::string &prim_name);
@@ -291,7 +291,7 @@ class CheckAndConvertUtils {
  private:
   static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);
   static std::map<std::string, TypePtr> _CheckArgumentType(const std::map<std::string, TypePtr> &arg,
-                                                           const std::set<TypePtr> &valid_values,
+                                                           const std::set<TypeId> &valid_values,
                                                            const std::string &prim_name);
   static std::map<std::string, TypePtr> _CheckTypeSame(const std::map<std::string, TypePtr> &arg1,
                                                        const std::map<std::string, TypePtr> &arg2,
@@ -17,21 +17,8 @@
 #include <vector>
 #include <memory>
 namespace mindspore {
-namespace {
-template <typename T>
-void SetTensorData(void *data, float num, size_t data_length) {
-  MS_EXCEPTION_IF_NULL(data);
-  auto tensor_data = reinterpret_cast<T *>(data);
-  MS_EXCEPTION_IF_NULL(tensor_data);
-  for (size_t index = 0; index < data_length; ++index) {
-    *tensor_data = num;
-    ++tensor_data;
-  }
-}
-} // namespace
 tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) {
   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
   size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
   auto tensor_data = tensor->data_c();
   char *data = reinterpret_cast<char *>(tensor_data);
@@ -43,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std
 tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) {
   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
-  auto mem_size = IntToSize(tensor->ElementsNum());
+  size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
   if (tensor->data_type() == kNumberTypeFloat32) {
-    SetTensorData<float>(tensor->data_c(), 1.0, mem_size);
+    SetTensorData(tensor->data_c(), 1.0f, mem_size);
   } else if (tensor->data_type() == kNumberTypeInt) {
-    SetTensorData<int>(tensor->data_c(), 1, mem_size);
+    SetTensorData(tensor->data_c(), 1, mem_size);
   }
   return tensor;
 }
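Review note: `SetTensorData` moves to the header as a fully deduced template; two things in `CreateOnesTensor` above are worth double-checking. First, the loop writes `data_length` elements (not bytes), so the new `GetTypeByte(...) * ElementsNum()` product looks like it over-runs the buffer by the element width; the zeros path is different because it `memset`s bytes. Second, deduction takes `T` from the literal, so a plain `1.0` would deduce `double` against a float32 buffer (hence the `f` suffix above). Minimal model of the intended use:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same shape as the header's SetTensorData: writes `data_length` ELEMENTS of T.
template <typename T>
void SetData(void *data, T num, size_t data_length) {
  auto *p = static_cast<T *>(data);
  for (size_t i = 0; i < data_length; ++i) {
    p[i] = num;
  }
}

int main() {
  std::vector<float> buf(4);
  // Pass the element count; a byte count (4x larger for float32) would write
  // past the end of the buffer.
  SetData(buf.data(), 1.0f, buf.size());
  assert(buf[0] == 1.0f && buf[3] == 1.0f);
  return 0;
}
```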
@@ -18,6 +18,16 @@
 #include <vector>
 #include "ir/tensor.h"
 namespace mindspore {
+template <typename T>
+void SetTensorData(void *data, T num, size_t data_length) {
+  MS_EXCEPTION_IF_NULL(data);
+  auto tensor_data = reinterpret_cast<T *>(data);
+  MS_EXCEPTION_IF_NULL(tensor_data);
+  for (size_t index = 0; index < data_length; ++index) {
+    *tensor_data = num;
+    ++tensor_data;
+  }
+}
 class TensorConstructUtils {
  public:
   static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape);