| @@ -149,7 +149,9 @@ add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/utils util) | |||||
| list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_core_utils_obj>) | list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_core_utils_obj>) | ||||
| add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/ir ir) | add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/ir ir) | ||||
| list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_ir_obj>) | list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_ir_obj>) | ||||
| add_dependencies(_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj proto_input ) | |||||
| add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/c_ops c_ops) | |||||
| list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_c_ops_obj>) | |||||
| add_dependencies(_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj _mindspore_c_ops_obj proto_input) | |||||
| set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME) | set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME) | ||||
| add_library(mindspore STATIC ${SUB_OBJECTS_SRC}) | add_library(mindspore STATIC ${SUB_OBJECTS_SRC}) | ||||
| @@ -0,0 +1,2 @@ | |||||
| file(GLOB_RECURSE _C_OPS_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||||
| add_library(_mindspore_c_ops_obj OBJECT ${_C_OPS_ALL_SRC_FILES}) | |||||
| @@ -0,0 +1,139 @@ | |||||
| /** | |||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||||
| * | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/conv2d.h" | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace { | |||||
| using PrimConv2dPtr = std::shared_ptr<Conv2d>; | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto conv_prim = primitive->cast<PrimConv2dPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(conv_prim); | |||||
| auto prim_name = conv_prim->name(); | |||||
| CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]", | |||||
| w_shape[1], conv_prim->name()); | |||||
| auto out_channel = conv_prim->GetOutputChannel(); | |||||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); | |||||
| std::vector<int> temp_w; | |||||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||||
| CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w, | |||||
| conv_prim->name()); | |||||
| auto kernel_size_h = w_shape[2]; | |||||
| auto kernel_size_w = w_shape[3]; | |||||
| auto stride = conv_prim->GetStride(); | |||||
| auto dilation = conv_prim->GetDilation(); | |||||
| auto stride_h = stride[2]; | |||||
| auto stride_w = stride[3]; | |||||
| auto dilation_h = dilation[2]; | |||||
| auto dilation_w = dilation[3]; | |||||
| int h_out = -1; | |||||
| int w_out = -1; | |||||
| std::vector<int> pad_list(4, 0); | |||||
| auto pad_mode = conv_prim->GetPadMode(); | |||||
| if (pad_mode == "valid") { | |||||
| h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | |||||
| w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | |||||
| } else if (pad_mode == "same") { | |||||
| h_out = ceil(x_shape[2] / stride_h); | |||||
| w_out = ceil(x_shape[3] / stride_w); | |||||
| auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); | |||||
| pad_list.emplace_back(floor(pad_needed_h / 2)); | |||||
| pad_list.emplace_back(pad_needed_h / 2); | |||||
| auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); | |||||
| auto pad_left = floor(pad_needed_w / 2); | |||||
| pad_list.emplace_back(pad_left); | |||||
| pad_list.emplace_back(pad_needed_h - pad_left); | |||||
| } else if (pad_mode == "pad") { | |||||
| std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list)); | |||||
| auto pad_top = conv_prim->GetPad()[0]; | |||||
| auto pad_bottom = conv_prim->GetPad()[1]; | |||||
| auto pad_right = conv_prim->GetPad()[2]; | |||||
| auto pad_left = conv_prim->GetPad()[3]; | |||||
| h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; | |||||
| w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; | |||||
| h_out = floor(h_out); | |||||
| w_out = floor(w_out); | |||||
| } | |||||
| conv_prim->SetPadList(pad_list); | |||||
| std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out}; | |||||
| return std::make_shared<abstract::Shape>(out_shape); | |||||
| } | |||||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name()); | |||||
| const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("x", input_args[0]->GetTypeTrack()); | |||||
| types.emplace("w", input_args[1]->GetTypeTrack()); | |||||
| CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||||
| if (x_type == kNumberTypeInt8) { | |||||
| return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32)); | |||||
| } | |||||
| return std::make_shared<TensorType>(TypeIdToType(x_type)); | |||||
| } | |||||
| } // namespace | |||||
| void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode, | |||||
| const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation, | |||||
| int group) { | |||||
| auto prim_name = this->name(); | |||||
| this->AddAttr("data_format", MakeValue("NCHW")); | |||||
| this->AddAttr("offset_a", MakeValue(0)); | |||||
| this->SetKernelSize(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | |||||
| this->SetStride(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), true, true)); | |||||
| this->SetDilation(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), true, true)); | |||||
| this->SetPadMode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name)); | |||||
| CheckAndConvertUtils::CheckInteger("pad size", pad.size(), kEqual, 4, prim_name); | |||||
| if (pad_mode == "pad") { | |||||
| for (auto item : pad) { | |||||
| CheckAndConvertUtils::Check("pad item", item, kGreaterEqual, "zeros list", 0, prim_name); | |||||
| } | |||||
| } else { | |||||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros list", {0, 0, 0, 0}, prim_name); | |||||
| } | |||||
| this->SetPad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); | |||||
| this->SetMode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 1, prim_name)); | |||||
| this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name)); | |||||
| this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); | |||||
| } | |||||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||||
| InferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,94 @@ | |||||
| /** | |||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||||
| * | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_CONV2D_H | |||||
| #define MINDSPORE_CORE_C_OPS_CONV2D_H | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "c_ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| class Conv2d : public PrimitiveC { | |||||
| public: | |||||
| Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); } | |||||
| void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid", | |||||
| const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1}, | |||||
| const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1); | |||||
| std::vector<int> GetKernelSize() const { | |||||
| auto value_ptr = this->GetAttr(kKernelSize); | |||||
| return GetValue<std::vector<int>>(value_ptr); | |||||
| } | |||||
| std::vector<int> GetStride() const { | |||||
| auto value_ptr = GetAttr(kStride); | |||||
| return GetValue<std::vector<int>>(value_ptr); | |||||
| } | |||||
| std::vector<int> GetDilation() const { | |||||
| auto value_ptr = GetAttr(kDilation); | |||||
| return GetValue<std::vector<int>>(value_ptr); | |||||
| } | |||||
| std::string GetPadMode() const { | |||||
| auto value_ptr = this->GetAttr(kPadMode); | |||||
| return GetValue<string>(value_ptr); | |||||
| } | |||||
| std::vector<int> GetPad() const { | |||||
| auto value_ptr = this->GetAttr(kPad); | |||||
| return GetValue<std::vector<int>>(value_ptr); | |||||
| } | |||||
| int GetMode() const { | |||||
| auto value_ptr = this->GetAttr(kMode); | |||||
| return GetValue<int>(value_ptr); | |||||
| } | |||||
| int GetGroup() const { | |||||
| auto value_ptr = this->GetAttr(kGroup); | |||||
| return GetValue<int>(value_ptr); | |||||
| } | |||||
| int GetOutputChannel() const { | |||||
| auto value_ptr = this->GetAttr(kOutputChannel); | |||||
| return GetValue<int>(value_ptr); | |||||
| } | |||||
| void SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); } | |||||
| void SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||||
| void SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||||
| void SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } | |||||
| void SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | |||||
| void SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } | |||||
| void SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); } | |||||
| void SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } | |||||
| void SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } | |||||
| private: | |||||
| inline static const string kKernelSize = "kernel_size"; | |||||
| inline static const string kStride = "stride"; | |||||
| inline static const string kDilation = "dilation"; | |||||
| inline static const string kPadMode = "pad_mode"; | |||||
| inline static const string kPad = "pad"; | |||||
| inline static const string kMode = "mode"; | |||||
| inline static const string kGroup = "group"; | |||||
| inline static const string kOutputChannel = "output channel"; | |||||
| inline static const string kPadList = "pad_list"; | |||||
| inline static const string kConv2DName = "Conv2D"; | |||||
| }; | |||||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_CONV2D_H | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||||
| * | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H | |||||
| #define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "ir/primitive.h" | |||||
| #include "ir/value.h" | |||||
| namespace mindspore { | |||||
| class PrimitiveC : public Primitive { | |||||
| public: | |||||
| explicit PrimitiveC(const std::string &name) : Primitive(name) { attrs_ = {}; } | |||||
| protected: | |||||
| void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name) { | |||||
| this->AddAttr("input_names", MakeValue(inputs_name)); | |||||
| this->AddAttr("output_names", MakeValue(outputs_name)); | |||||
| } | |||||
| }; | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H | |||||
| @@ -632,6 +632,19 @@ void FuncGraph::CheckOrder() { | |||||
| MS_LOG(DEBUG) << "Check order okay."; | MS_LOG(DEBUG) << "Check order okay."; | ||||
| } | } | ||||
| } | } | ||||
| CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive_node = std::make_shared<ValueNode>(primitive); | |||||
| std::vector<AnfNodePtr> input_node_list = {primitive_node}; | |||||
| std::copy(inputs.begin(), inputs.end(), std::back_inserter(input_node_list)); | |||||
| return NewCNode(input_node_list); | |||||
| } | |||||
| ParameterPtr FuncGraph::add_parameter(const tensor::MetaTensorPtr &meta_tensor) { | |||||
| auto parameter = add_parameter(); | |||||
| parameter->set_default_param(MakeValue(meta_tensor)); | |||||
| parameter->set_abstract(meta_tensor->ToAbstract()); | |||||
| return parameter; | |||||
| } | |||||
| size_t NewFgSeenGeneration() { | size_t NewFgSeenGeneration() { | ||||
| static size_t fg_seen_generation = 0; | static size_t fg_seen_generation = 0; | ||||
| @@ -170,7 +170,9 @@ class FuncGraph : public FuncGraphBase { | |||||
| // create a cnode with given inputs, bound to this graph, and set to specific scope | // create a cnode with given inputs, bound to this graph, and set to specific scope | ||||
| CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope); | CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope); | ||||
| virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs); | |||||
| virtual ParameterPtr add_parameter(const tensor::MetaTensorPtr &meta_tensor); | |||||
| // Functions for handling variable argument, keyword-only arguments and variable keyword argument | // Functions for handling variable argument, keyword-only arguments and variable keyword argument | ||||
| AnfNodePtr GetDefaultValueByName(const std::string &name); | AnfNodePtr GetDefaultValueByName(const std::string &name); | ||||
| void set_param_default_value(const std::string &name, const AnfNodePtr &node) { | void set_param_default_value(const std::string &name, const AnfNodePtr &node) { | ||||
| @@ -0,0 +1,270 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include <utility> | |||||
| #include "abstract/abstract_value.h" | |||||
| namespace mindspore { | |||||
| namespace { | |||||
| const std::map<CompareEnum, std::function<bool(int, int)>> kCompareMap = { | |||||
| {kEqual, [](int num1, int num2) -> bool { return num1 == num2; }}, | |||||
| {kNotEqual, [](int num1, int num2) -> bool { return num1 != num2; }}, | |||||
| {kLessThan, [](int num1, int num2) -> bool { return num1 < num2; }}, | |||||
| {kLessEqual, [](int num1, int num2) -> bool { return num1 <= num2; }}, | |||||
| {kGreaterThan, [](int num1, int num2) -> bool { return num1 > num2; }}, | |||||
| {kGreaterEqual, [](int num1, int num2) -> bool { return num1 >= num2; }}}; | |||||
| const std::map<CompareRange, std::function<bool(int, std::pair<int, int>)>> kCompareRangeMap = { | |||||
| {kIncludeNeither, | |||||
| [](int num1, std::pair<int, int> range) -> bool { return num1 > range.first && num1 < range.second; }}, | |||||
| {kIncludeLeft, | |||||
| [](int num1, std::pair<int, int> range) -> bool { return num1 >= range.first && num1 < range.second; }}, | |||||
| {kIncludeRight, | |||||
| [](int num1, std::pair<int, int> range) -> bool { return num1 > range.first && num1 <= range.second; }}, | |||||
| {kIncludeBoth, | |||||
| [](int num1, std::pair<int, int> range) -> bool { return num1 >= range.first && num1 <= range.second; }}}; | |||||
| const std::map<CompareEnum, std::string> kCompareToString = { | |||||
| {kEqual, "equal"}, {kNotEqual, "not equal"}, {kLessThan, "less than"}, | |||||
| {kLessEqual, "less eqaul"}, {kGreaterThan, "greater than"}, {kGreaterEqual, "greate equal"}}; | |||||
| const std::map<CompareRange, std::pair<std::string, std::string>> kCompareRangeToString = { | |||||
| {kIncludeNeither, {"in (", ")"}}, | |||||
| {kIncludeLeft, {" in [", ")"}}, | |||||
| {kIncludeRight, {"in (", "]"}}, | |||||
| {kIncludeBoth, {"in [", "]"}}}; | |||||
| } // namespace | |||||
| bool CheckAndConvertUtils::IsEqualVector(const std::vector<int> &vec_1, const std::vector<int> &vec_2) { | |||||
| if (vec_1.size() != vec_2.size()) { | |||||
| return false; | |||||
| } | |||||
| for (size_t index = 0; index < vec_1.size(); ++index) { | |||||
| if (vec_1[index] != vec_2[index]) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| std::vector<int> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name, | |||||
| const std::vector<int> &arg_value, | |||||
| const std::string &prim_name, bool allow_four, | |||||
| bool ret_four) { | |||||
| if (arg_value.size() == 2) { | |||||
| return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[1]} : arg_value; | |||||
| } else if (arg_value.size() == 4 && allow_four) { | |||||
| return ret_four ? arg_value : std::vector<int>{arg_value[2], arg_value[3]}; | |||||
| } | |||||
| std::ostringstream buffer; | |||||
| buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two "; | |||||
| if (allow_four) { | |||||
| buffer << "or four "; | |||||
| } | |||||
| buffer << " positive int numbers , but got ["; | |||||
| for (auto item : arg_value) { | |||||
| buffer << item << ","; | |||||
| } | |||||
| buffer << "]"; | |||||
| MS_EXCEPTION(ValueError) << buffer.str(); | |||||
| } | |||||
| std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value, | |||||
| const std::set<std::string> &check_list, const std::string &prim_name) { | |||||
| if (check_list.find(arg_value) != check_list.end()) { | |||||
| return arg_value; | |||||
| } | |||||
| std::ostringstream buffer; | |||||
| buffer << "For " << prim_name << " the " << arg_name << " should be str and must be "; | |||||
| if (check_list.size() == 1) { | |||||
| buffer << (*check_list.begin()) << "but got " << arg_value; | |||||
| MS_EXCEPTION(ValueError) << buffer.str(); | |||||
| } | |||||
| buffer << "one of {"; | |||||
| for (const auto &item : check_list) { | |||||
| buffer << item << " ,"; | |||||
| } | |||||
| buffer << " }" | |||||
| << " but got " << arg_value; | |||||
| MS_EXCEPTION(ValueError) << buffer.str(); | |||||
| } | |||||
| int CheckAndConvertUtils::CheckInteger(const std::string &arg_name, int arg_value, CompareEnum compare_operator, | |||||
| int match_value, const std::string &prim_name) { | |||||
| auto iter = kCompareMap.find(compare_operator); | |||||
| if (iter == kCompareMap.end()) { | |||||
| MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; | |||||
| } | |||||
| if (iter->second(arg_value, match_value)) { | |||||
| return arg_value; | |||||
| } | |||||
| std::ostringstream buffer; | |||||
| if (prim_name.empty()) { | |||||
| buffer << "The "; | |||||
| } else { | |||||
| buffer << "For " << prim_name << " the "; | |||||
| } | |||||
| buffer << arg_name << " must "; | |||||
| auto iter_to_string = kCompareToString.find(compare_operator); | |||||
| if (iter_to_string == kCompareToString.end()) { | |||||
| MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map"; | |||||
| } | |||||
| buffer << iter_to_string->second << match_value << " , but got " << arg_value; | |||||
| MS_EXCEPTION(ValueError) << buffer.str(); | |||||
| } | |||||
| void CheckAndConvertUtils::CheckInRange(const std::string &arg_name, int arg_value, CompareRange compare_operator, | |||||
| const std::pair<int, int> &range, const std::string &prim_name) { | |||||
| auto iter = kCompareRangeMap.find(compare_operator); | |||||
| if (iter == kCompareRangeMap.end()) { | |||||
| MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; | |||||
| } | |||||
| if (iter->second(arg_value, range)) { | |||||
| return; | |||||
| } | |||||
| std::ostringstream buffer; | |||||
| if (prim_name.empty()) { | |||||
| buffer << "The "; | |||||
| } else { | |||||
| buffer << "For " << prim_name << " the "; | |||||
| } | |||||
| buffer << arg_name << " must "; | |||||
| auto iter_to_string = kCompareRangeToString.find(compare_operator); | |||||
| if (iter_to_string == kCompareRangeToString.end()) { | |||||
| MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map"; | |||||
| } | |||||
| auto range_strng = iter_to_string->second; | |||||
| buffer << range_strng.first << range.first << "," << range_strng.second << " , but got " << arg_value; | |||||
| MS_EXCEPTION(ValueError) << buffer.str(); | |||||
| } | |||||
| std::vector<int> CheckAndConvertUtils::ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape, | |||||
| const std::string &prim_name) { | |||||
| MS_EXCEPTION_IF_NULL(shape); | |||||
| if (!shape->isa<abstract::Shape>()) { | |||||
| MS_EXCEPTION(ValueError) << "The " << arg_name << "'s shape is " << shape->ToString() | |||||
| << "should be a common shape!"; | |||||
| } | |||||
| auto shape_element = shape->cast<abstract::ShapePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(shape_element); | |||||
| return shape_element->shape(); | |||||
| } | |||||
| TypeId CheckAndConvertUtils::ConvertTypePtrToTypeId(const string &arg_name, const TypePtr &type_ptr, | |||||
| const string &prim_name) { | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||||
| if (!type_ptr->isa<TensorType>() || !type_ptr->isa<Number>()) { | |||||
| MS_EXCEPTION(ValueError) << "The " << arg_name << "'s shape is " << type_ptr->ToString() | |||||
| << "should be a common type!(tensor_type && numbertype)"; | |||||
| } | |||||
| return type_ptr->type_id(); | |||||
| } | |||||
| void CheckAndConvertUtils::Check(const string &arg_name, int arg_value, CompareEnum compare_type, | |||||
| const string &value_name, int value, const string &prim_name, | |||||
| ExceptionType exception_type) { | |||||
| auto iter = kCompareMap.find(compare_type); | |||||
| if (iter == kCompareMap.end()) { | |||||
| MS_EXCEPTION(NotExistsError) << "the compare type :" << compare_type << " is not in the compare map"; | |||||
| } | |||||
| if (iter->second(arg_value, value)) { | |||||
| return; | |||||
| } | |||||
| std::ostringstream buffer; | |||||
| if (prim_name.empty()) { | |||||
| buffer << "The "; | |||||
| } else { | |||||
| buffer << "For " << prim_name << " the "; | |||||
| } | |||||
| auto iter_to_string = kCompareToString.find(compare_type); | |||||
| if (iter_to_string == kCompareToString.end()) { | |||||
| MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map"; | |||||
| } | |||||
| MS_EXCEPTION(exception_type) << buffer.str() << arg_name << " should be " << iter_to_string->second << value | |||||
| << " but got " << arg_value; | |||||
| } | |||||
| void CheckAndConvertUtils::Check(const string &arg_name, const std::vector<int> &arg_value, CompareEnum compare_type, | |||||
| const string &value_name, const std::vector<int> &value, const string &prim_name, | |||||
| ExceptionType exception_type) { | |||||
| if (compare_type != kEqual) { | |||||
| auto iter = kCompareToString.find(compare_type); | |||||
| if (iter != kCompareToString.end()) { | |||||
| MS_EXCEPTION(NotSupportError) << "Only supported equal to compare two vectors but got " << iter->second; | |||||
| } | |||||
| MS_EXCEPTION(UnknownError) << "Cannot find the operator " << compare_type << "in the compare map!"; | |||||
| } | |||||
| if (arg_value == value) { | |||||
| return; | |||||
| } | |||||
| std::ostringstream buffer; | |||||
| if (prim_name.empty()) { | |||||
| buffer << "The "; | |||||
| } else { | |||||
| buffer << "For " << prim_name << " the "; | |||||
| } | |||||
| auto iter_to_string = kCompareToString.find(compare_type); | |||||
| if (iter_to_string == kCompareToString.end()) { | |||||
| MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map"; | |||||
| } | |||||
| buffer << arg_name << "should be " << iter_to_string->second << " ["; | |||||
| for (auto item : value) { | |||||
| buffer << item << ","; | |||||
| } | |||||
| buffer << "] " | |||||
| << "but got ["; | |||||
| for (auto item : arg_value) { | |||||
| buffer << item << " ,"; | |||||
| } | |||||
| buffer << "]"; | |||||
| MS_EXCEPTION(exception_type) << buffer.str(); | |||||
| } | |||||
| void CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, | |||||
| const std::set<TypeId> &check_list, const std::string &prim_name) { | |||||
| if (types.empty()) { | |||||
| MS_LOG(WARNING) << "Tryinh to use the function to check a empty types map!"; | |||||
| return; | |||||
| } | |||||
| std::set<TypeId> types_id; | |||||
| std::ostringstream buffer; | |||||
| buffer << "For " << prim_name; | |||||
| for (const auto &type : types) { | |||||
| MS_EXCEPTION_IF_NULL(type.second); | |||||
| if (!type.second->isa<TensorType>()) { | |||||
| MS_EXCEPTION(TypeError) << "The " << prim_name << "'s" << type.first << " input must be tensor type but got " | |||||
| << type.second->ToString(); | |||||
| } | |||||
| types_id.emplace(type.second->type_id()); | |||||
| } | |||||
| if (types_id.size() > 1) { | |||||
| buffer << "'s input type is not same : "; | |||||
| for (const auto &item : types) { | |||||
| buffer << "[ name : " << item.first << " ,type : " << item.second->ToString() << "]"; | |||||
| } | |||||
| MS_EXCEPTION(TypeError) << buffer.str(); | |||||
| } | |||||
| if (check_list.find(*(types_id.begin())) != check_list.end()) { | |||||
| buffer << " type of "; | |||||
| for (const auto &elem : types) { | |||||
| buffer << elem.first << " should be in ["; | |||||
| for (auto type_elem : check_list) { | |||||
| buffer << type_elem << " ,"; | |||||
| } | |||||
| buffer << "] , but got " << types.begin()->second->ToString(); | |||||
| } | |||||
| } | |||||
| MS_EXCEPTION(TypeError) << buffer.str(); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H | |||||
| #define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <map> | |||||
| #include <set> | |||||
| #include <utility> | |||||
| #include "base/base.h" | |||||
| #include "ir/anf.h" | |||||
| #include "ir/dtype/type_id.h" | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| enum CompareEnum : int { | |||||
| kEqual = 1, // == | |||||
| kNotEqual = 2, // != | |||||
| kLessThan = 3, // < | |||||
| kLessEqual = 4, // <= | |||||
| kGreaterThan = 5, // > | |||||
| kGreaterEqual = 6, // >= | |||||
| }; | |||||
| enum CompareRange { | |||||
| kIncludeNeither = 1, // (a,b) | |||||
| kIncludeLeft = 2, // [a,b) | |||||
| kIncludeRight = 3, // (a,b] | |||||
| kIncludeBoth = 4, // [a,b] | |||||
| }; | |||||
| class CheckAndConvertUtils { | |||||
| public: | |||||
| static std::vector<int> CheckPositiveVector(const std::string &arg_name, const std::vector<int> &arg_value, | |||||
| const std::string &prim_name, bool allow_four = false, | |||||
| bool ret_four = false); | |||||
| static std::string CheckString(const std::string &arg_name, const std::string &arg_value, | |||||
| const std::set<std::string> &check_list, const std::string &prim_name); | |||||
| static int CheckInteger(const std::string &arg_name, int arg_value, CompareEnum compare_operator, int match_value, | |||||
| const std::string &prim_name); | |||||
| static void CheckInRange(const std::string &arg_name, int arg_value, CompareRange compare_operator, | |||||
| const std::pair<int, int> &range, const std::string &prim_name); | |||||
| static std::vector<int> ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape, | |||||
| const std::string &prim_name); | |||||
| static TypeId ConvertTypePtrToTypeId(const std::string &arg_name, const TypePtr &type_ptr, | |||||
| const std::string &prim_name); | |||||
| static void Check(const std::string &arg_name, int arg_value, CompareEnum compare_type, const std::string &value_name, | |||||
| int value, const std::string &prim_name = "", ExceptionType exception_type = ValueError); | |||||
| static void Check(const std::string &arg_name, const std::vector<int> &arg_value, CompareEnum compare_type, | |||||
| const std::string &value_name, const std::vector<int> &value, const std::string &prim_name = "", | |||||
| ExceptionType exception_type = ValueError); | |||||
| static void CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypeId> &check_list, | |||||
| const std::string &prim_name); | |||||
| private: | |||||
| static bool IsEqualVector(const std::vector<int> &vec_1, const std::vector<int> &vec_2); | |||||
| }; | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H | |||||