From: @wangnan39 Reviewed-by: @kingxian Signed-off-by:tags/v1.2.0-rc1
| @@ -18,6 +18,7 @@ | |||
| #include <set> | |||
| #include "common/trans.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "utils/utils.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| @@ -66,9 +67,17 @@ void SetTransNodeAttr(const CNodePtr &trans_node) { | |||
| std::string InitDefaultFormat(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) { | |||
| auto attr = AnfAlgo::GetNodeAttr<std::string>(node, kAttrFormat); | |||
| if (attr == kOpFormat_NCDHW) { | |||
| return kOpFormat_NCDHW; | |||
| auto primitive_ptr = GetCNodePrimitive(node); | |||
| MS_EXCEPTION_IF_NULL(primitive_ptr); | |||
| auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat); | |||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | |||
| int64_t data_format; | |||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format); | |||
| if (!result) { | |||
| auto attr = GetValue<std::string>(data_format_ptr); | |||
| if (attr == kOpFormat_NCDHW) { | |||
| return kOpFormat_NCDHW; | |||
| } | |||
| } | |||
| } else if (AnfAlgo::IsRealKernel(node)) { | |||
| auto formats = AnfAlgo::GetAllOutputFormats(node); | |||
| @@ -23,6 +23,7 @@ | |||
| #include "utils/utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| @@ -46,9 +47,15 @@ bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vecto | |||
| if (group == 1) { | |||
| return false; | |||
| } | |||
| auto data_format = AnfAlgo::GetNodeAttr<std::string>(conv2d, kAttrFormat); | |||
| if (data_format != "NCHW") { | |||
| MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1, but got " << data_format; | |||
| auto primitive_ptr = GetCNodePrimitive(conv2d); | |||
| MS_EXCEPTION_IF_NULL(primitive_ptr); | |||
| auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat); | |||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | |||
| int64_t data_format; | |||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format); | |||
| if (!result || data_format != Format::NCHW) { | |||
| MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1"; | |||
| } | |||
| if (in_shape.size() != kConv2DAxisNum || out_shape.size() != kConv2DAxisNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size() | |||
| @@ -21,6 +21,7 @@ | |||
| #include "base/core_ops.h" | |||
| #include "ir/param_info.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/kernel_compiler/kernel_build_info.h" | |||
| @@ -402,9 +403,17 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| } | |||
| SetKernelInfoForNode(cnode); | |||
| if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode)) { | |||
| auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFormat); | |||
| if (attr == kOpFormat_NCDHW) { | |||
| ResetInFormat(cnode, kOpFormat_NCDHW); | |||
| auto primitive_ptr = GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(primitive_ptr); | |||
| auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat); | |||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | |||
| int64_t data_format; | |||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format); | |||
| if (!result) { | |||
| auto attr = GetValue<std::string>(data_format_ptr); | |||
| if (attr == kOpFormat_NCDHW) { | |||
| ResetInFormat(cnode, kOpFormat_NCDHW); | |||
| } | |||
| } | |||
| } | |||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | |||
| @@ -29,6 +29,7 @@ | |||
| #include "utils/convert_utils_py.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/primitive_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "pipeline/jit/resource.h" | |||
| #include "pipeline/pynative/pynative_execute.h" | |||
| @@ -280,6 +281,8 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { | |||
| if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) { | |||
| attr_name = kOpAttrNameReplaceMap[attr_name]; | |||
| } | |||
| const std::string &prim_name = this->name(); | |||
| CheckAndConvertUtils::ConvertAttrValueToInt(prim_name, attr_name, &converted_ret); | |||
| (void)this->AddAttr(attr_name, converted_ret); | |||
| } | |||
| @@ -26,6 +26,7 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "base/core_ops.h" | |||
| #include "proto/mind_ir.pb.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| using FloatPtr = std::shared_ptr<Float>; | |||
| @@ -425,7 +426,9 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons | |||
| MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); | |||
| mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_name(attr.first); | |||
| SetValueToAttributeProto(attr.second, attr_proto); | |||
| auto attr_value = attr.second; | |||
| CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr_value); | |||
| SetValueToAttributeProto(attr_value, attr_proto); | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name(); | |||
| @@ -25,6 +25,7 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "base/core_ops.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| enum OpMergeMode { | |||
| @@ -102,8 +103,9 @@ void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_Attrib | |||
| void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, | |||
| onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); | |||
| auto attr_value = GetValue<std::string>(value); | |||
| if (attr_value == "VALID") { | |||
| int64_t attr_value; | |||
| CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value, true); | |||
| if (attr_value == PadMode::VALID) { | |||
| attr_proto->set_s("VALID"); | |||
| } else { | |||
| attr_proto->set_s("SAME_UPPER"); | |||
| @@ -186,10 +188,11 @@ OPERATOR_ONNX_CONVERT_DEFINE( | |||
| [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, | |||
| const PrimitivePtr &prim) { | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); | |||
| auto attr_value = GetValue<std::string>(value); | |||
| if (attr_value == "valid") { | |||
| int64_t attr_value; | |||
| CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value); | |||
| if (attr_value == PadMode::VALID) { | |||
| attr_proto->set_s("VALID"); | |||
| } else if (attr_value == "same") { | |||
| } else if (attr_value == PadMode::SAME) { | |||
| attr_proto->set_s("SAME_UPPER"); | |||
| } else { // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads' | |||
| attr_proto->set_name("pads"); | |||
| @@ -834,12 +837,13 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/ | |||
| // set pad | |||
| onnx_attr_proto = node_proto->add_attribute(); | |||
| auto attr_value = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||
| int64_t attr_value; | |||
| CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr("pad_mode"), &attr_value); | |||
| onnx_attr_proto->set_name("auto_pad"); | |||
| onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); | |||
| if (attr_value == "valid") { | |||
| if (attr_value == PadMode::VALID) { | |||
| onnx_attr_proto->set_s("VALID"); | |||
| } else if (attr_value == "same") { | |||
| } else if (attr_value == PadMode::SAME) { | |||
| onnx_attr_proto->set_s("SAME_UPPER"); | |||
| } else { | |||
| onnx_attr_proto->set_name("pads"); | |||
| @@ -59,20 +59,16 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0"; | |||
| } | |||
| std::set<std::string> available_pad_mode{"pad", "same", "valid"}; | |||
| auto pad_mode_ptr = primitive->GetAttr("pad_mode"); | |||
| if ((pad_mode_ptr != nullptr) && pad_mode_ptr->isa<StringImm>()) { | |||
| auto pad_mode = pad_mode_ptr->cast<StringImmPtr>()->value(); | |||
| if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) { | |||
| MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid"; | |||
| } | |||
| if (pad_mode == "valid") { | |||
| if (pad_mode_ptr != nullptr) { | |||
| int64_t pad_mode; | |||
| CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true); | |||
| if (pad_mode == PadMode::VALID) { | |||
| padding = 0; | |||
| } else if (pad_mode == "same") { | |||
| } else if (pad_mode == PadMode::SAME) { | |||
| padding = (window - 1) / 2; | |||
| } | |||
| } | |||
| std::set<std::string> available_mode{"max", "avg"}; | |||
| auto mode_ptr = primitive->GetAttr("mode"); | |||
| if ((mode_ptr != nullptr) && mode_ptr->isa<StringImm>()) { | |||
| @@ -270,13 +266,13 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit | |||
| void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h, | |||
| const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride, | |||
| const std::vector<int64_t> &dilation, const std::string &pad_mode, | |||
| const std::vector<int64_t> &dilation, const int64_t &pad_mode, | |||
| const std::vector<int64_t> &padding) { | |||
| if (pad_mode == "valid") { | |||
| if (pad_mode == PadMode::VALID) { | |||
| output_hw->push_back(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])); | |||
| output_hw->push_back(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])); | |||
| pad_list->insert(pad_list->begin(), 4, 0); | |||
| } else if (pad_mode == "same") { | |||
| } else if (pad_mode == PadMode::SAME) { | |||
| output_hw->push_back(std::ceil((x_h * 1.0) / stride[0])); | |||
| output_hw->push_back(std::ceil((x_w * 1.0) / stride[1])); | |||
| int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h; | |||
| @@ -287,7 +283,7 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa | |||
| pad_needed_w = std::max((int64_t)0, pad_needed_w); | |||
| pad_list->push_back(std::floor(pad_needed_w / 2)); | |||
| pad_list->push_back(pad_needed_w - pad_list->at(2)); | |||
| } else if (pad_mode == "pad") { | |||
| } else if (pad_mode == PadMode::PAD) { | |||
| pad_list->insert(pad_list->begin(), padding.begin(), padding.end()); | |||
| output_hw->push_back(std::floor( | |||
| 1 + | |||
| @@ -298,6 +294,15 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa | |||
| } | |||
| } | |||
| int64_t GetAndCheckFormat(const ValuePtr &value) { | |||
| int64_t data_format; | |||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format); | |||
| if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) { | |||
| MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC"; | |||
| } | |||
| return data_format; | |||
| } | |||
| AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| @@ -322,12 +327,12 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| CheckShapeAnyAndPositive(op_name + " w_shape", w_shape); | |||
| CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape); | |||
| CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape); | |||
| std::string data_format = CheckAttrStringSet(op_name, primitive->GetAttr("format"), "format", {"NCHW", "NHWC"}); | |||
| int64_t n_axis = 0; | |||
| int64_t c_axis = 1; | |||
| int64_t h_axis = 2; | |||
| int64_t w_axis = 3; | |||
| if (data_format == "NHWC") { | |||
| int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format")); | |||
| if (data_format == Format::NHWC) { | |||
| c_axis = 3; | |||
| h_axis = 1; | |||
| w_axis = 2; | |||
| @@ -352,8 +357,8 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| std::vector<int64_t> stride = CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), 2, 2); | |||
| std::vector<int64_t> dilation = CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), 2, 2); | |||
| std::vector<int64_t> padding = CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), 0, 4); | |||
| std::string pad_mode = | |||
| CheckAttrStringSet(op_name, primitive->GetAttr("pad_mode"), "pad_mode", {"pad", "same", "valid"}); | |||
| int64_t pad_mode; | |||
| CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode); | |||
| std::vector<int64_t> output_hw; | |||
| std::vector<int64_t> pad_list; | |||
| std::vector<int64_t> output_hw_min; | |||
| @@ -378,7 +383,7 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| ShapeVector output_shape; | |||
| ShapeVector output_shape_min; | |||
| ShapeVector output_shape_max; | |||
| if (data_format == "NHWC") { | |||
| if (data_format == Format::NHWC) { | |||
| output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel}; | |||
| output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel}; | |||
| output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel}; | |||
| @@ -426,16 +431,12 @@ AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| ShapeVector bias_shape = bias->shape()->shape(); | |||
| ShapeVector x_min_shape = x->shape()->min_shape(); | |||
| ShapeVector x_max_shape = x->shape()->max_shape(); | |||
| std::set<std::string> available_data_format{"NCHW", "NHWC"}; | |||
| auto data_format_ptr = primitive->GetAttr("format"); | |||
| std::string data_format = "NCHW"; | |||
| if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) { | |||
| data_format = data_format_ptr->cast<StringImmPtr>()->value(); | |||
| } | |||
| if (available_data_format.find(data_format) == available_data_format.end()) { | |||
| MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ", use NCHW or NHWC."; | |||
| int64_t data_format = Format::NCHW; | |||
| if (data_format_ptr != nullptr) { | |||
| data_format = GetAndCheckFormat(data_format_ptr); | |||
| } | |||
| auto x_channel = data_format == "NHWC" ? x_shape[x_shape.size() - 1] : x_shape[1]; | |||
| auto x_channel = data_format == Format::NHWC ? x_shape[x_shape.size() - 1] : x_shape[1]; | |||
| // Additional check for dynamic shape | |||
| // Last infer will be real shape values | |||
| bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); | |||
| @@ -29,6 +29,7 @@ | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| using std::string; | |||
| @@ -494,7 +495,11 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind | |||
| case FORM_PARSE_SCALAR: { | |||
| std::size_t value_pos(0); | |||
| if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) { | |||
| auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto); | |||
| ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto); | |||
| const std::string &op_type = prim->name(); | |||
| if (!IsLite()) { | |||
| CheckAndConvertUtils::ConvertAttrValueToInt(op_type, attr_name, &res); | |||
| } | |||
| prim->AddAttr(attr_name, res); | |||
| break; | |||
| } | |||
| @@ -39,6 +39,8 @@ class MSANFModelParser { | |||
| std::string GetProducerName() { return producer_name_; } | |||
| std::string GetProducerVersion() { return model_version_; } | |||
| std::string GetIrVersion() { return ir_version_; } | |||
| void SetLite() { is_lite_ = true; } | |||
| bool IsLite() { return is_lite_; } | |||
| private: | |||
| bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); | |||
| @@ -68,6 +70,7 @@ class MSANFModelParser { | |||
| std::string producer_name_; | |||
| std::string model_version_; | |||
| std::string ir_version_; | |||
| bool is_lite_ = false; | |||
| std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_; | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -71,7 +71,7 @@ std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) { | |||
| return buf; | |||
| } | |||
| std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) { | |||
| std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite) { | |||
| auto graphBuf = ReadProtoFile(file_name); | |||
| if (graphBuf == nullptr) { | |||
| MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str(); | |||
| @@ -79,7 +79,7 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) { | |||
| } | |||
| try { | |||
| auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size()); | |||
| auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size(), is_lite); | |||
| return graph; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); | |||
| @@ -87,7 +87,7 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) { | |||
| } | |||
| } | |||
| std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size) { | |||
| std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) { | |||
| MS_EXCEPTION_IF_NULL(buf); | |||
| std::string str((const char *)buf, buf_size); | |||
| mind_ir::ModelProto model_; | |||
| @@ -95,6 +95,9 @@ std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_ | |||
| MS_LOG(ERROR) << "Parse model from buffer fail!"; | |||
| } | |||
| MSANFModelParser model_parser; | |||
| if (is_lite) { | |||
| model_parser.SetLite(); | |||
| } | |||
| FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); | |||
| return dstgraph_ptr; | |||
| } | |||
| @@ -24,8 +24,8 @@ | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore { | |||
| std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name); | |||
| std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false); | |||
| std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file); | |||
| std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size); | |||
| std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_LOAD_MODEL_H | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,13 +14,16 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "utils/check_convert_utils.h" | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <typeinfo> | |||
| #include <functional> | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "ops/op_utils.h" | |||
| #include "ir/dtype/type.h" | |||
| #include "ir/dtype/tensor_type.h" | |||
| #include "ir/dtype.h" | |||
| @@ -84,21 +87,21 @@ AttrConverterPair PadModeUpperConverter(PadModToEnumUpperMap, PadModToStrUpperMa | |||
| AttrConverterPair ReductionConverter(ReductionToEnumMap, ReductionToStrMap); | |||
| static std::map<std::string, AttrConverterPair> FormatAndPadAttrMap = { | |||
| {"format", DataFormatConverter}, | |||
| {"pad_mode", PadModeConverter}, | |||
| {ops::kFormat, DataFormatConverter}, | |||
| {ops::kPadMode, PadModeConverter}, | |||
| }; | |||
| static std::map<std::string, AttrConverterPair> FormatAndPadUpperAttrMap = { | |||
| {"format", DataFormatConverter}, | |||
| {"pad_mode", PadModeUpperConverter}, | |||
| {ops::kFormat, DataFormatConverter}, | |||
| {ops::kPadMode, PadModeUpperConverter}, | |||
| }; | |||
| static std::map<std::string, AttrConverterPair> DataFormatMap = { | |||
| {"format", DataFormatConverter}, | |||
| {ops::kFormat, DataFormatConverter}, | |||
| }; | |||
| static std::map<std::string, AttrConverterPair> ReductionMap = { | |||
| {"reduction", ReductionConverter}, | |||
| {ops::kReduction, ReductionConverter}, | |||
| }; | |||
| static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = { | |||
| @@ -132,24 +135,42 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC | |||
| {"BinaryCrossEntropy", ReductionMap}, | |||
| {"BinaryCrossEntropyGrad", ReductionMap}, | |||
| {"NLLLoss", ReductionMap}, | |||
| {"DepthToSpace", DataFormatMap}, | |||
| }; | |||
| int64_t CheckAndConvertUtils::GetDataFormatEnumValue(const std::string &value) { | |||
| if (DataFormatToEnumMap.find(value) == DataFormatToEnumMap.end()) { | |||
| MS_LOG(ERROR) << "Can not convert data format " << value << "to enum"; | |||
| bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) { | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| if (value->isa<StringImm>()) { | |||
| auto attr_value_str = GetValue<std::string>(value); | |||
| if (DataFormatToEnumMap.find(attr_value_str) == DataFormatToEnumMap.end()) { | |||
| MS_LOG(DEBUG) << "The data format " << attr_value_str << " not be converted to enum."; | |||
| return false; | |||
| } | |||
| *enum_value = DataFormatToEnumMap[attr_value_str]; | |||
| return true; | |||
| } else { | |||
| *enum_value = GetValue<int64_t>(value); | |||
| return true; | |||
| } | |||
| return DataFormatToEnumMap[value]; | |||
| return false; | |||
| } | |||
| int64_t CheckAndConvertUtils::GetPadModEnumValue(const std::string &value, bool is_upper) { | |||
| std::map<std::string, int64_t> pad_map = PadModToEnumMap; | |||
| if (is_upper) { | |||
| pad_map = PadModToEnumUpperMap; | |||
| } | |||
| if (pad_map.find(value) == pad_map.end()) { | |||
| MS_LOG(ERROR) << "Can not convert pad mode " << value << "to enum"; | |||
| void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper) { | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| if (value->isa<StringImm>()) { | |||
| auto attr_value_str = GetValue<std::string>(value); | |||
| std::map<std::string, int64_t> pad_map = PadModToEnumMap; | |||
| if (is_upper) { | |||
| pad_map = PadModToEnumUpperMap; | |||
| } | |||
| if (pad_map.find(attr_value_str) == pad_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same"; | |||
| } | |||
| *enum_value = pad_map[attr_value_str]; | |||
| } else { | |||
| *enum_value = GetValue<int64_t>(value); | |||
| } | |||
| return pad_map[value]; | |||
| } | |||
| AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) { | |||
| @@ -172,8 +193,8 @@ AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op | |||
| bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, | |||
| ValuePtr *const value) { | |||
| if (value == nullptr) { | |||
| MS_LOG(ERROR) << "value is nullptr"; | |||
| if (value == nullptr || *value == nullptr) { | |||
| MS_LOG(DEBUG) << "value of attr " << op_type << attr_name << " is nullptr."; | |||
| return false; | |||
| } | |||
| if (!(*value)->isa<StringImm>()) { | |||
| @@ -191,12 +212,17 @@ bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, con | |||
| } | |||
| if (!do_convert) { | |||
| transform(real_value.begin(), real_value.end(), real_value.begin(), ::toupper); | |||
| if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) { | |||
| do_convert = true; | |||
| } | |||
| } | |||
| if (!do_convert) { | |||
| transform(real_value.begin(), real_value.end(), real_value.begin(), ::tolower); | |||
| if (attr_map_pair.first.find(real_value) == attr_map_pair.first.end()) { | |||
| MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to int"; | |||
| return false; | |||
| } | |||
| } | |||
| *value = MakeValue<int64_t>(attr_map_pair.first[real_value]); | |||
| MS_LOG(DEBUG) << "convert str to int, name: " << op_type << ", attr: " << attr_name; | |||
| return true; | |||
| @@ -204,7 +230,7 @@ bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, con | |||
| bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, | |||
| ValuePtr *const value) { | |||
| if (value == nullptr) { | |||
| if (value == nullptr || *value == nullptr) { | |||
| MS_LOG(ERROR) << "value is nullptr"; | |||
| return false; | |||
| } | |||
| @@ -226,7 +252,6 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, | |||
| return true; | |||
| } | |||
| namespace { | |||
| typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction; | |||
| @@ -242,7 +267,6 @@ std::map<std::string, AttrFunction> kIrAttrToOpAttr = {{"L2Normalize", {{"axis", | |||
| {"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}}; | |||
| } // namespace | |||
| bool CheckAndConvertUtils::IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2) { | |||
| if (vec_1.size() != vec_2.size()) { | |||
| return false; | |||
| @@ -284,8 +284,8 @@ class CheckAndConvertUtils { | |||
| static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); | |||
| static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); | |||
| static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name); | |||
| static int64_t GetDataFormatEnumValue(const std::string &value); | |||
| static int64_t GetPadModEnumValue(const std::string &value, bool is_upper = false); | |||
| static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value); | |||
| static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false); | |||
| static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); | |||
| private: | |||
| @@ -856,7 +856,7 @@ int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model | |||
| int AnfImporterFromMindir::Import(const converter::Flags *flag) { | |||
| #if SUPPORT_TRAIN | |||
| func_graph_ = LoadMindIR(flag->modelFile); | |||
| func_graph_ = LoadMindIR(flag->modelFile, true); | |||
| if (func_graph_ != nullptr) { | |||
| return RET_OK; | |||
| } else { | |||
| @@ -866,7 +866,7 @@ int AnfImporterFromMindir::Import(const converter::Flags *flag) { | |||
| onnx_model_ = ReadOnnxFromBinary(flag->modelFile); | |||
| if (onnx_model_ == nullptr) { | |||
| MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model"; | |||
| func_graph_ = LoadMindIR(flag->modelFile); | |||
| func_graph_ = LoadMindIR(flag->modelFile, true); | |||
| if (func_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file."; | |||
| return RET_GRAPH_FILE_ERR; | |||