From 677e6158e31b1d6ebca1504a958c27a739261b96 Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Tue, 11 Aug 2020 17:45:12 +0800 Subject: [PATCH] add infer function of primitive c --- mindspore/core/abstract/prim_nn.cc | 88 ------------------- mindspore/core/abstract/primitive_infer_map.h | 2 +- mindspore/core/c_ops/conv2d.cc | 86 ++++++++++++++++++ mindspore/core/c_ops/conv2d.h | 1 - mindspore/core/c_ops/primitive_c.cc | 36 ++++++++ mindspore/core/c_ops/primitive_c.h | 7 +- mindspore/core/utils/check_convert_utils.cc | 34 +++---- mindspore/core/utils/check_convert_utils.h | 6 +- 8 files changed, 142 insertions(+), 118 deletions(-) create mode 100644 mindspore/core/c_ops/primitive_c.cc diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index fbd5b3be76..54dfe0ecab 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -18,8 +18,6 @@ #include "abstract/utils.h" #include "abstract/param_validator.h" #include "utils/check_convert_utils.h" -#include "c_ops/conv2d.h" -#include "abstract/primitive_infer_map.h" namespace mindspore { namespace abstract { @@ -428,91 +426,5 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti return std::make_shared(std::make_shared(kAnyValue, kUInt8), std::make_shared(std::vector{shape_y})); } - -abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto conv_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(conv_prim); - auto prim_name = conv_prim->name(); - CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); - auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name); - - CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); - CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); - CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]", - w_shape[1], conv_prim->name()); - auto out_channel = conv_prim->GetOutputChannel(); - CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); - std::vector temp_w; - std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); - CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w, - conv_prim->name()); - - auto kernel_size_h = w_shape[2]; - auto kernel_size_w = w_shape[3]; - auto stride = conv_prim->GetStride(); - auto dilation = conv_prim->GetDilation(); - auto stride_h = stride[2]; - auto stride_w = stride[3]; - auto dilation_h = dilation[2]; - auto dilation_w = dilation[3]; - int h_out = -1; - int w_out = -1; - std::vector pad_list(4, 0); - auto pad_mode = conv_prim->GetPadMode(); - if (pad_mode == "valid") { - h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); - w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); - } else if (pad_mode == "same") { - h_out = ceil(x_shape[2] / stride_h); - w_out = ceil(x_shape[3] / stride_w); - - auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); - pad_list.emplace_back(floor(pad_needed_h / 2)); - pad_list.emplace_back(pad_needed_h / 2); - auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); - auto pad_left = floor(pad_needed_w / 2); - pad_list.emplace_back(pad_left); - pad_list.emplace_back(pad_needed_h - pad_left); - } else if (pad_mode == "pad") { - std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list)); - auto pad_top = conv_prim->GetPad()[0]; - auto pad_bottom = conv_prim->GetPad()[1]; - auto pad_right = conv_prim->GetPad()[2]; - auto pad_left = conv_prim->GetPad()[3]; - - h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; - w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; - h_out = floor(h_out); - w_out = floor(w_out); - } - conv_prim->SetPadList(pad_list); - std::vector out_shape = {x_shape[0], out_channel, h_out, w_out}; - return std::make_shared(out_shape); -} - -TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector &input_args) { - CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } - auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name()); - const std::set valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; - std::map types; - types.emplace("x", input_args[0]->GetTypeTrack()); - types.emplace("w", input_args[1]->GetTypeTrack()); - CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); - if (x_type == kNumberTypeInt8) { - return std::make_shared(TypeIdToType(kNumberTypeInt32)); - } - return std::make_shared(TypeIdToType(x_type)); -} -AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args) { - return std::make_shared(Conv2dInferType(primitive, input_args), - Conv2dInferShape(primitive, input_args)->shape()); -} -REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.h b/mindspore/core/abstract/primitive_infer_map.h index 87380c60e4..1274e1ac50 100644 --- a/mindspore/core/abstract/primitive_infer_map.h +++ b/mindspore/core/abstract/primitive_infer_map.h @@ -47,7 +47,7 @@ class RegisterStandardPrimitiveEvalHelper { }; #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ - static auto helper_##name = RegisterStandardPrimitiveEvalHelper(primitive, impl) + static auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl) } // namespace abstract } // namespace mindspore #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ diff --git a/mindspore/core/c_ops/conv2d.cc b/mindspore/core/c_ops/conv2d.cc index 77c6a61438..bf8520ce1d 100644 --- a/mindspore/core/c_ops/conv2d.cc +++ b/mindspore/core/c_ops/conv2d.cc @@ -23,6 +23,7 @@ #include #include #include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" namespace mindspore { namespace { @@ -36,6 +37,84 @@ constexpr auto kGroup = "group"; constexpr auto kOutputChannel = "output channel"; constexpr auto kPadList = "pad_list"; constexpr auto kConv2DName = "Conv2D"; +abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto conv_prim = std::dynamic_pointer_cast(primitive); + MS_EXCEPTION_IF_NULL(conv_prim); + auto prim_name = conv_prim->name(); + CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); + + CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); + CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); + CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]", + w_shape[1], conv_prim->name()); + auto out_channel = conv_prim->GetOutputChannel(); + CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); + std::vector temp_w; + std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); + CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w, + conv_prim->name()); + + auto kernel_size_h = w_shape[2]; + auto kernel_size_w = w_shape[3]; + auto stride = conv_prim->GetStride(); + auto dilation = conv_prim->GetDilation(); + auto stride_h = stride[2]; + auto stride_w = stride[3]; + auto dilation_h = dilation[2]; + auto dilation_w = dilation[3]; + int h_out = -1; + int w_out = -1; + std::vector pad_list(4, 0); + auto pad_mode = conv_prim->GetPadMode(); + if (pad_mode == "valid") { + h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); + w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); + } else if (pad_mode == "same") { + h_out = ceil(x_shape[2] / stride_h); + w_out = ceil(x_shape[3] / stride_w); + + auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); + pad_list.emplace_back(floor(pad_needed_h / 2)); + pad_list.emplace_back(pad_needed_h / 2); + auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); + auto pad_left = floor(pad_needed_w / 2); + pad_list.emplace_back(pad_left); + pad_list.emplace_back(pad_needed_h - pad_left); + } else if (pad_mode == "pad") { + std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list)); + auto pad_top = conv_prim->GetPad()[0]; + auto pad_bottom = conv_prim->GetPad()[1]; + auto pad_right = conv_prim->GetPad()[2]; + auto pad_left = conv_prim->GetPad()[3]; + + h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; + w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; + h_out = floor(h_out); + w_out = floor(w_out); + } + conv_prim->SetPadList(pad_list); + std::vector out_shape = {x_shape[0], out_channel, h_out, w_out}; + return std::make_shared(out_shape); +} + +TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector &input_args) { + CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + const std::set valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; + std::map types; + types.emplace("x", input_args[0]->BuildType()); + types.emplace("w", input_args[1]->BuildType()); + auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); + if (infer_type == kNumberTypeInt8) { + return std::make_shared(TypeIdToType(kNumberTypeInt32)); + } + return TypeIdToType(infer_type); +} } // namespace Conv2d::Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); } @@ -105,4 +184,11 @@ void Conv2d::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } void Conv2d::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); } void Conv2d::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } void Conv2d::SetPadList(const std::vector &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } + +AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + return std::make_shared(Conv2dInferType(primitive, input_args), + Conv2dInferShape(primitive, input_args)->shape()); +} +REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); } // namespace mindspore diff --git a/mindspore/core/c_ops/conv2d.h b/mindspore/core/c_ops/conv2d.h index 910fad18af..89b5259fc7 100644 --- a/mindspore/core/c_ops/conv2d.h +++ b/mindspore/core/c_ops/conv2d.h @@ -55,5 +55,4 @@ AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const Primitive const std::vector &input_args); using PrimConv2dPtr = std::shared_ptr; } // namespace mindspore - #endif // MINDSPORE_CORE_C_OPS_CONV2D_H_ diff --git a/mindspore/core/c_ops/primitive_c.cc b/mindspore/core/c_ops/primitive_c.cc new file mode 100644 index 0000000000..e9d98a6567 --- /dev/null +++ b/mindspore/core/c_ops/primitive_c.cc @@ -0,0 +1,36 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "c_ops/primitive_c.h" +#include +#include +namespace mindspore { +void PrimitiveC::InitIOName(const std::vector &inputs_name, const std::vector &outputs_name) { + this->AddAttr("input_names", MakeValue(inputs_name)); + this->AddAttr("output_names", MakeValue(outputs_name)); +} + +AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) { + auto infer_map = abstract::GetPrimitiveToEvalImplMap(); + auto iter = infer_map.find(std::make_shared(this->name())); + if (iter == infer_map.end()) { + MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!"; + } + auto infer_function = iter->second.impl_; + return infer_function(nullptr, shared_from_base(), abstract_list); +} +} // namespace mindspore diff --git a/mindspore/core/c_ops/primitive_c.h b/mindspore/core/c_ops/primitive_c.h index 501f32f964..a006eb9aca 100644 --- a/mindspore/core/c_ops/primitive_c.h +++ b/mindspore/core/c_ops/primitive_c.h @@ -21,17 +21,16 @@ #include #include #include "ir/primitive.h" +#include "abstract/primitive_infer_map.h" #include "ir/value.h" namespace mindspore { class PrimitiveC : public Primitive { public: explicit PrimitiveC(const std::string &name) : Primitive(name) {} + AbstractBasePtr Infer(const AbstractBasePtrList &abstract_list); protected: - void InitIOName(const std::vector &inputs_name, const std::vector &outputs_name) { - this->AddAttr("input_names", MakeValue(inputs_name)); - this->AddAttr("output_names", MakeValue(outputs_name)); - } + void InitIOName(const std::vector &inputs_name, const std::vector &outputs_name); }; } // namespace mindspore #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index 140f3efc0d..1c7fe9abf9 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -39,8 +39,8 @@ const std::map)>> kCom [](int num1, std::pair range) -> bool { return num1 >= range.first && num1 <= range.second; }}}; const std::map kCompareToString = { - {kEqual, "equal"}, {kNotEqual, "not equal"}, {kLessThan, "less than"}, - {kLessEqual, "less eqaul"}, {kGreaterThan, "greater than"}, {kGreaterEqual, "greate equal"}}; + {kEqual, "equal "}, {kNotEqual, "not equal "}, {kLessThan, "less than "}, + {kLessEqual, "less eqaul "}, {kGreaterThan, "greater than "}, {kGreaterEqual, "greate equal "}}; const std::map> kCompareRangeToString = { {kIncludeNeither, {"in (", ")"}}, @@ -162,16 +162,6 @@ std::vector CheckAndConvertUtils::ConvertShapePtrToShape(const std::string return shape_element->shape(); } -TypeId CheckAndConvertUtils::ConvertTypePtrToTypeId(const string &arg_name, const TypePtr &type_ptr, - const string &prim_name) { - MS_EXCEPTION_IF_NULL(type_ptr); - if (!type_ptr->isa() || !type_ptr->isa()) { - MS_EXCEPTION(ValueError) << "The " << arg_name << "'s shape is " << type_ptr->ToString() - << "should be a common type!(tensor_type && numbertype)"; - } - return type_ptr->type_id(); -} - void CheckAndConvertUtils::Check(const string &arg_name, int arg_value, CompareEnum compare_type, const string &value_name, int value, const string &prim_name, ExceptionType exception_type) { @@ -231,11 +221,10 @@ void CheckAndConvertUtils::Check(const string &arg_name, const std::vector MS_EXCEPTION(exception_type) << buffer.str(); } -void CheckAndConvertUtils::CheckTensorTypeSame(const std::map &types, - const std::set &check_list, const std::string &prim_name) { +TypeId CheckAndConvertUtils::CheckTensorTypeSame(const std::map &types, + const std::set &check_list, const std::string &prim_name) { if (types.empty()) { - MS_LOG(WARNING) << "Tryinh to use the function to check a empty types map!"; - return; + MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!"; } std::set types_id; std::ostringstream buffer; @@ -246,7 +235,11 @@ void CheckAndConvertUtils::CheckTensorTypeSame(const std::mapToString(); } - types_id.emplace(type.second->type_id()); + auto tensor_type = type.second->cast(); + MS_EXCEPTION_IF_NULL(tensor_type); + auto element = tensor_type->element(); + MS_EXCEPTION_IF_NULL(element); + types_id.emplace(element->type_id()); } if (types_id.size() > 1) { buffer << "'s input type is not same : "; @@ -255,16 +248,17 @@ void CheckAndConvertUtils::CheckTensorTypeSame(const std::mapToString() << " ,"; } buffer << "] , but got " << types.begin()->second->ToString(); } + MS_EXCEPTION(TypeError) << buffer.str(); } - MS_EXCEPTION(TypeError) << buffer.str(); + return *types_id.begin(); } } // namespace mindspore diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index f61f22d19b..77b78d2189 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -55,15 +55,13 @@ class CheckAndConvertUtils { const std::pair &range, const std::string &prim_name); static std::vector ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape, const std::string &prim_name); - static TypeId ConvertTypePtrToTypeId(const std::string &arg_name, const TypePtr &type_ptr, - const std::string &prim_name); static void Check(const std::string &arg_name, int arg_value, CompareEnum compare_type, const std::string &value_name, int value, const std::string &prim_name = "", ExceptionType exception_type = ValueError); static void Check(const std::string &arg_name, const std::vector &arg_value, CompareEnum compare_type, const std::string &value_name, const std::vector &value, const std::string &prim_name = "", ExceptionType exception_type = ValueError); - static void CheckTensorTypeSame(const std::map &types, const std::set &check_list, - const std::string &prim_name); + static TypeId CheckTensorTypeSame(const std::map &types, const std::set &check_list, + const std::string &prim_name); private: static bool IsEqualVector(const std::vector &vec_1, const std::vector &vec_2);