| @@ -699,7 +699,6 @@ build_lite() | |||
| if [[ "X$COMPILE_LITE" = "Xon" ]]; then | |||
| build_lite | |||
| exit | |||
| else | |||
| build_mindspore | |||
| fi | |||
| @@ -202,6 +202,7 @@ union PrimitiveType { | |||
| NegGrad, | |||
| LogGrad, | |||
| BatchToSpaceND, | |||
| LshProjection, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -120,6 +120,12 @@ enum PaddingMode : byte { | |||
| MODE_RESERVED = 3 | |||
| } | |||
| enum LshProjectionType : byte { | |||
| UNKNOWN = 0, | |||
| SPARSE = 1, | |||
| DENSE = 2 | |||
| } | |||
| table Pad { | |||
| paddings: [int]; | |||
| paddingMode: PaddingMode; | |||
| @@ -661,7 +667,8 @@ enum ReduceMode : byte { | |||
| ReduceMin = 2, | |||
| ReduceProd = 3, | |||
| ReduceSum = 4, | |||
| ReduceSumSquare = 5 | |||
| ReduceSumSquare = 5, | |||
| ReduceASum = 6 | |||
| } | |||
| table Reduce { | |||
| @@ -785,7 +792,7 @@ table FloorMod { | |||
| table L2Norm { | |||
| axis: [int]; | |||
| epsilon: float; | |||
| activationType: ActivationType; | |||
| activationType: ActivationType = 0; | |||
| } | |||
| table LogicalAnd { | |||
| @@ -937,3 +944,7 @@ table BlackBox { | |||
| size : int; | |||
| address : [ubyte]; | |||
| } | |||
| table LshProjection { | |||
| type : LshProjectionType; | |||
| } | |||
| @@ -106,7 +106,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me | |||
| if (i >= dst_node->inputIndex.size()) { | |||
| MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() | |||
| << " quant_params; but only " << dst_node->inputIndex.size() << " input"; | |||
| break; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto activate_index = dst_node->inputIndex[i]; | |||
| auto tensor_input = meta_graph->allTensors[activate_index].get(); | |||
| @@ -170,7 +170,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> & | |||
| } | |||
| } | |||
| void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *return_node) { | |||
| MS_ASSERT(nullptr != meta_graph); | |||
| MS_ASSERT(nullptr != return_node); | |||
| @@ -178,31 +178,34 @@ void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_p | |||
| auto input_node = cnode->input(i); | |||
| if (input_node == nullptr) { | |||
| MS_LOG(ERROR) << "output node is nullptr"; | |||
| return; | |||
| return RET_NULL_PTR; | |||
| } else if (input_node->isa<CNode>()) { | |||
| auto ret = ConvertInputCNode(input_node, return_node); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "obtain outputs failed"; | |||
| return; | |||
| return ret; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node"; | |||
| return; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < return_node->inputIndex.size(); ++i) { | |||
| meta_graphT->outputIndex.push_back(return_node->inputIndex[i]); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph) { | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| auto meta_graphT = std::make_unique<schema::MetaGraphT>(); | |||
| int ret = RET_OK; | |||
| for (const auto &cnode : cnodes) { | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return nullptr; | |||
| ret = RET_MEMORY_FAILED; | |||
| break; | |||
| } | |||
| if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem || | |||
| primitive_c->Type() == schema::PrimitiveType_MakeTuple || | |||
| @@ -216,32 +219,41 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||
| auto node = std::make_unique<schema::CNodeT>(); | |||
| if (node == nullptr) { | |||
| MS_LOG(ERROR) << "object failed to be constructed"; | |||
| return nullptr; | |||
| ret = RET_MEMORY_FAILED; | |||
| break; | |||
| } | |||
| if (primT->value.type == schema::PrimitiveType_Return) { | |||
| node->name = "return_node"; | |||
| SetGraphoutputIndex(cnode, meta_graphT, node.get()); | |||
| ret = SetGraphoutputIndex(cnode, meta_graphT, node.get()); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SetOpOutputN failed"; | |||
| break; | |||
| } | |||
| continue; | |||
| } | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->name = cnode->fullname_with_scope(); | |||
| node->primitive = std::unique_ptr<schema::PrimitiveT>(primT); | |||
| auto ret = SetOpInputNode(cnode, meta_graphT, node.get()); | |||
| ret = SetOpInputNode(cnode, meta_graphT, node.get()); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SetOpInputNode failed"; | |||
| return nullptr; | |||
| break; | |||
| } | |||
| SetOpOutputNode(cnode, meta_graphT, node.get()); | |||
| ret = ConvertQuantParam(meta_graphT, primitive_c, node); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvertQuantParam failed"; | |||
| return nullptr; | |||
| break; | |||
| } | |||
| if (!keep_graph) { | |||
| primitive_c->ClearPrimitiveT(); | |||
| } | |||
| meta_graphT->nodes.emplace_back(std::move(node)); | |||
| } | |||
| if (ret != RET_OK) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||
| return nullptr; | |||
| } | |||
| // set graph input tensors | |||
| SetGraphInputIndex(meta_graphT); | |||
| return meta_graphT.release(); | |||
| @@ -297,11 +309,11 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod | |||
| auto abstractBase = paramNode->abstract(); | |||
| if (abstractBase == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); | |||
| return RET_ERROR; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name(); | |||
| return RET_ERROR; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||
| auto typePtr = abstractTensor->element()->GetTypeTrack(); | |||
| @@ -309,7 +321,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod | |||
| paramTensor->dataType = typePtr->type_id(); | |||
| if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | |||
| MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name(); | |||
| return RET_ERROR; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||
| auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param()); | |||
| @@ -431,13 +443,13 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch | |||
| auto ret = ConvertInputCNode(input_node, fb_node); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvertInputCNode failed"; | |||
| return RET_ERROR; | |||
| return ret; | |||
| } | |||
| } else if (input_node->isa<Parameter>()) { | |||
| auto ret = ConvertInputParameter(input_node, meta_graphT, fb_node); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvertInputParameter failed"; | |||
| return RET_ERROR; | |||
| return ret; | |||
| } | |||
| if (!input_node->cast<ParameterPtr>()->has_default()) { | |||
| is_graph_input = true; | |||
| @@ -24,6 +24,7 @@ | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "ir/func_graph.h" | |||
| #include "tools/converter/return_code.h" | |||
| namespace mindspore::lite { | |||
| class AnfExporter { | |||
| @@ -45,7 +46,7 @@ class AnfExporter { | |||
| int ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | |||
| void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT); | |||
| void SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *return_node); | |||
| bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); | |||
| int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| @@ -54,7 +54,7 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||
| char *tensor_data = new (std::nothrow) char[size]; | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "new char[] failed"; | |||
| return RET_ERROR; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| std::memcpy(tensor_data, tensor->data.data(), size); | |||
| param_value->set_tensor_addr(tensor_data); | |||
| @@ -128,7 +128,7 @@ int AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr<schema::CNo | |||
| auto tuple_get_item_prim_ptr = GetTupleGetItemPrim(); | |||
| if (tuple_get_item_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | |||
| return RET_ERROR; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | |||
| auto get_item_value = NewValueNode(MakeValue<int>(i)); | |||
| @@ -153,16 +153,16 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { | |||
| auto node = GetNode(j); | |||
| if (nullptr == node) { | |||
| MS_LOG(ERROR) << "Can't find input node."; | |||
| return RET_ERROR; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| op_inputs.push_back(node); | |||
| } | |||
| auto new_cnode = func_graph_->NewCNode(op_inputs); | |||
| new_cnode->set_fullname_with_scope(cNode->name); | |||
| auto ret = ConvertAbstract(cNode, new_cnode); | |||
| if (ret != RET_OK) { | |||
| auto status = ConvertAbstract(cNode, new_cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvertAbstract failed."; | |||
| return RET_ERROR; | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| @@ -176,7 +176,7 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() { | |||
| auto make_tuple_prim_ptr = GetMakeTuplePrim(); | |||
| if (make_tuple_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; | |||
| return RET_ERROR; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | |||
| make_tuple_inputs.emplace_back(make_tuple_prim); | |||
| @@ -184,7 +184,7 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() { | |||
| auto cNode = GetNode(tensor_id); | |||
| if (nullptr == cNode) { | |||
| MS_LOG(ERROR) << "Can't find input node."; | |||
| return RET_ERROR; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| make_tuple_inputs.emplace_back(cNode); | |||
| } | |||
| @@ -195,7 +195,7 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() { | |||
| auto return_prim_ptr = GetReturnPrim(); | |||
| if (return_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||
| return RET_ERROR; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto value_node = NewValueNode(return_prim_ptr); | |||
| op_inputs.emplace_back(value_node); | |||
| @@ -207,14 +207,14 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() { | |||
| auto return_prim_ptr = GetReturnPrim(); | |||
| if (return_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||
| return RET_ERROR; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto value_node = NewValueNode(return_prim_ptr); | |||
| std::vector<AnfNodePtr> op_inputs{value_node}; | |||
| auto cnode = GetNode(meta_graph_->outputIndex.front()); | |||
| if (nullptr == cnode) { | |||
| MS_LOG(ERROR) << "Can't find input node."; | |||
| return RET_ERROR; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| op_inputs.emplace_back(cnode); | |||
| auto return_cnode = func_graph_->NewCNode(op_inputs); | |||
| @@ -201,23 +201,23 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | |||
| bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, | |||
| int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, | |||
| const onnx::ValueInfoProto &value_proto) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!value_proto.has_type() || !value_proto.has_name()) { | |||
| MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; | |||
| return false; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| node->set_name(value_proto.name()); | |||
| const auto &type_proto = value_proto.type(); | |||
| if (!type_proto.has_tensor_type()) { | |||
| MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! "; | |||
| return false; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type(); | |||
| if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) { | |||
| MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! "; | |||
| return false; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); | |||
| std::vector<int> shape; | |||
| @@ -227,7 +227,7 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod | |||
| if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { | |||
| MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; | |||
| return false; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); | |||
| @@ -248,7 +248,7 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod | |||
| MS_LOG(ERROR) << "memcpy_s error"; | |||
| delete tensor_data_buf; | |||
| delete tensor_info; | |||
| return false; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| @@ -261,10 +261,10 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod | |||
| delete tensor_info; | |||
| } | |||
| anfnode_build_map_[value_proto.name()] = node; | |||
| return true; | |||
| return RET_OK; | |||
| } | |||
| bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||
| int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::GraphProto &importProto) { | |||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||
| MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); | |||
| @@ -273,20 +273,22 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu | |||
| const onnx::TensorProto &initializer_proto = importProto.initializer(i); | |||
| if (!initializer_proto.has_name()) { | |||
| MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; | |||
| return false; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| default_para_map_[initializer_proto.name()] = initializer_proto; | |||
| } | |||
| int status = RET_OK; | |||
| MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); | |||
| for (int i = 0; i < importProto.input_size(); ++i) { | |||
| const onnx::ValueInfoProto &input_proto = importProto.input(i); | |||
| if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { | |||
| status = BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; | |||
| return false; | |||
| break; | |||
| } | |||
| } | |||
| return true; | |||
| return status; | |||
| } | |||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| @@ -662,7 +664,7 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||
| return true; | |||
| } | |||
| bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||
| int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType) { | |||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||
| @@ -674,22 +676,25 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc | |||
| if (node_type == kConstantValueNode) { | |||
| if (!BuildValueNodeForFuncGraph(node_proto)) { | |||
| MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; | |||
| return false; | |||
| return RET_ERROR; | |||
| } | |||
| continue; | |||
| } | |||
| cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType); | |||
| if (cnode_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | |||
| return false; | |||
| return RET_NULL_PTR; | |||
| } | |||
| } | |||
| BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); | |||
| return true; | |||
| if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | |||
| MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType) { | |||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||
| GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); | |||
| @@ -697,47 +702,51 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph | |||
| if (importProto.has_name()) { | |||
| debug_info_ptr->set_name(importProto.name()); | |||
| } else { | |||
| MS_LOG(ERROR) << "FuncGraph under converting has not name!"; | |||
| MS_LOG(INFO) << "FuncGraph under converting has not name!"; | |||
| } | |||
| if (!ImportParametersForGraph(outputFuncGraph, importProto)) { | |||
| return false; | |||
| auto status = ImportParametersForGraph(outputFuncGraph, importProto); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| return ImportNodesForGraph(outputFuncGraph, importProto, quantType); | |||
| } | |||
| bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||
| int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||
| if (!model_proto.has_producer_name()) { | |||
| MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | |||
| return false; | |||
| return RET_GRAPH_FILE_ERR; | |||
| } | |||
| producer_name_ = model_proto.producer_name(); | |||
| if (!model_proto.has_model_version()) { | |||
| MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; | |||
| return false; | |||
| return RET_GRAPH_FILE_ERR; | |||
| } | |||
| model_version_ = model_proto.model_version(); | |||
| if (!model_proto.has_ir_version()) { | |||
| MS_LOG(ERROR) << "Parse model version from pb file failed!"; | |||
| return false; | |||
| return RET_GRAPH_FILE_ERR; | |||
| } | |||
| ir_version_ = model_proto.ir_version(); | |||
| return true; | |||
| return RET_OK; | |||
| } | |||
| int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||
| FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>(); | |||
| MS_EXCEPTION_IF_NULL(dstGraph); | |||
| if (!ParseModelConfigureInfo(*onnx_model_)) { | |||
| int status = ParseModelConfigureInfo(*onnx_model_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | |||
| return status; | |||
| } | |||
| const onnx::GraphProto &graphBuild = onnx_model_->graph(); | |||
| if (!BuildFuncGraph(dstGraph, graphBuild, quantType)) { | |||
| status = BuildFuncGraph(dstGraph, graphBuild, quantType); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Build funcgraph failed!"; | |||
| func_graph_ = nullptr; | |||
| return RET_ERROR; | |||
| return status; | |||
| } | |||
| func_graph_ = dstGraph; | |||
| MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; | |||
| @@ -45,13 +45,13 @@ class AnfImporterFromProtobuf : public AnfImporter { | |||
| int ConverterConstTensor() override { return RET_ERROR; }; | |||
| int ConverterCNode() override { return RET_ERROR; }; | |||
| int AddReturnCNode() override { return RET_ERROR; }; | |||
| bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); | |||
| bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| int ParseModelConfigureInfo(const onnx::ModelProto &model_proto); | |||
| int BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType); | |||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||
| bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| int ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||
| int ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType); | |||
| bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); | |||
| int BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); | |||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, | |||
| const schema::QuantType &quantType); | |||
| bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| @@ -61,25 +61,31 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(old_graph); | |||
| if (new_graph == nullptr) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); | |||
| return nullptr; | |||
| } | |||
| // quant | |||
| if (config != nullptr) { | |||
| if (config->quantType == schema::QuantType_PostTraining) { | |||
| this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8); | |||
| if (mQuantizer == nullptr) { | |||
| MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| return nullptr; | |||
| } | |||
| } else if (config->quantType == schema::QuantType_WeightQuant) { | |||
| auto bitNum = static_cast<size_t>(std::stoull(config->bitNum)); | |||
| if (bitNum != quant::UINT8_QUANTIZATION) { | |||
| MS_LOG(ERROR) << "Current Only Support 8 bit weight quant"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>( | |||
| new_graph, config->quantSize, config->convWeightQuantChannelThreshold, config->bitNum); | |||
| if (mQuantizer == nullptr) { | |||
| MS_LOG(ERROR) << "New WeightQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| return nullptr; | |||
| } | |||
| } | |||
| @@ -89,6 +95,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| auto status = mQuantizer->DoQuantize(new_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Quant failed " << status; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| if (config->quantType == schema::QuantType_PostTraining) { | |||
| @@ -97,6 +104,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| status = quant_cast.Run(new_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "add QuantCast error"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| } | |||
| @@ -23,6 +23,7 @@ | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "ir/anf.h" | |||
| #include "tools/converter/quantizer/quantizer.h" | |||
| #include "tools/converter/return_code.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -66,6 +66,8 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| if (flag->fmk == converter::FmkType_MS) { | |||
| MS_ASSERT(nullptr != modelImporter); | |||
| modelImporter->Import(flag->quantType); | |||
| int status = modelImporter->Import(flag->quantType); | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| graph = modelImporter->GetResult(); | |||
| } else { | |||
| MS_ASSERT(nullptr != modelParser); | |||
| @@ -94,8 +96,9 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| transform->SetGraphDef(meta_graph); | |||
| transform->CreateQuantizer(flag); | |||
| auto status = transform->Transform(*flag); | |||
| if (status != 0) { | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Transform meta graph failed " << status; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| @@ -106,15 +109,16 @@ int RunConverter(int argc, const char **argv) { | |||
| std::unique_ptr<converter::Flags> flags(new (std::nothrow) converter::Flags); | |||
| if (flags == nullptr) { | |||
| MS_LOG(ERROR) << "new flags error "; | |||
| return RET_ERROR; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| auto status = flags->Init(argc, argv); | |||
| if (status == RET_SUCCESS_EXIT) { | |||
| return 0; | |||
| return status; | |||
| } | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "converter::Flags Init failed: " << status; | |||
| return 1; | |||
| std::cout << "CONVERTER::FLAGS INIT FAILED" << std::endl; | |||
| return status; | |||
| } | |||
| // Load graph | |||
| std::string modelName = flags->modelFile.substr(flags->modelFile.find_last_of(DELIM_SLASH) + 1); | |||
| @@ -147,9 +151,11 @@ int RunConverter(int argc, const char **argv) { | |||
| return 1; | |||
| } | |||
| } | |||
| status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); | |||
| if (fb_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Convert model return nullptr"; | |||
| return 1; | |||
| std::cout << "CONVERT RESULT: FAILED!" << std::endl; | |||
| return status; | |||
| } | |||
| // save graph to file | |||
| @@ -158,13 +164,14 @@ int RunConverter(int argc, const char **argv) { | |||
| status = storage.Save(*fb_graph, flags->outputFile); | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "Save graph failed"; | |||
| return 1; | |||
| std::cout << "SAVE GRAPH FAILED!" << std::endl; | |||
| return RET_ERROR; | |||
| } | |||
| delete fb_graph; | |||
| MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!"; | |||
| return 0; | |||
| std::cout << "CONVERT RESULT: SUCCESS!" << std::endl; | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -25,6 +25,7 @@ | |||
| #include "tools/anf_importer/anf_importer.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/anf_transform.h" | |||
| #include "tools/converter/return_code.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -120,6 +120,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| status = mQuantizer->DetermineNodeQuantType(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DetermineNodeQuant failed"; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| @@ -142,7 +143,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||
| if (formatTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||
| return RET_ERROR; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| formatTransPass->SetQuantType(ctx.quantType); | |||
| formatTransPass->SetFmk(ctx.fmk); | |||
| @@ -154,7 +155,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_ERR) { | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| @@ -196,7 +197,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||
| if (dTypeTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new dTypeTransPass failed"; | |||
| return RET_ERROR; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| dTypeTransPass->SetInputDataDType(ctx.inputInferenceType); | |||
| dTypeTransPass->SetOutputDataDType(ctx.inferenceType); | |||
| @@ -117,6 +117,13 @@ STATUS InferShapePass::Run(MetaGraphT *graph) { | |||
| if (ret == RET_INFER_INVALID) { | |||
| MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name | |||
| << ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type) << "flag set to false."; | |||
| for (auto input_tensor : input_tensors) { | |||
| delete input_tensor; | |||
| } | |||
| for (auto output_tensor : output_tensors) { | |||
| delete output_tensor; | |||
| } | |||
| return RET_INFER_INVALID; | |||
| } else if (ret != RET_OK) { | |||
| MS_LOG(WARNING) << "InferShape failed, name: " << node->name | |||
| << ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type); | |||
| @@ -22,7 +22,7 @@ | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/anf_importer/import_from_meta_graphT.h" | |||
| #include "ir/anf.h" | |||
| #include "include/errorcode.h" | |||
| #include "tools/converter/return_code.h" | |||
| namespace mindspore::lite { | |||
| using namespace schema; | |||
| @@ -35,8 +35,12 @@ class ModelParser { | |||
| FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType = QuantType_QUANT_NONE) { | |||
| auto *meta_graph = ParseToFb(modelFile, weightFile, quantType); | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "parse model to fb failed"; | |||
| return nullptr; | |||
| } | |||
| auto func_graph = this->Fb2Anf(meta_graph); | |||
| delete (meta_graph); | |||
| delete(meta_graph); | |||
| return func_graph; | |||
| } | |||
| @@ -48,9 +52,10 @@ class ModelParser { | |||
| MS_EXCEPTION_IF_NULL(meta_graph); | |||
| auto func_graph = std::make_shared<FuncGraph>(); | |||
| AnfImporterFromMetaGraphT importer(meta_graph, func_graph); | |||
| auto ret = importer.Import(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "Import anf_graph from meta_graphT failed, ret: " << ret; | |||
| auto status = importer.Import(); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "Import anf_graph from meta_graphT failed, ret: " << status; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| return func_graph; | |||
| @@ -52,7 +52,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) { | |||
| for (auto &opDef : graphDefT->nodes) { | |||
| for (auto pass : this->nodePasses) { | |||
| status = pass->Run(new GraphNode(graphDefT, opDef.get())); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_ERR) { | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run NodePass failed"; | |||
| return status; | |||
| } else { | |||
| @@ -65,7 +65,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) { | |||
| for (auto pass : this->graphPasses) { | |||
| status = pass->Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_ERR) { | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run GraphPass failed"; | |||
| return status; | |||
| } else { | |||
| @@ -31,4 +31,5 @@ add_library(caffe_parser_mid OBJECT | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_tanh_parser.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_exp_parser.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_slice_parser.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_reduce_parser.cc | |||
| ) | |||
| @@ -38,7 +38,16 @@ STATUS CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||
| return RET_NULL_PTR; | |||
| } | |||
| // set default params | |||
| attr->outMaxValue = false; | |||
| attr->topK = 1; | |||
| const caffe::ArgMaxParameter argmaxParam = proto.argmax_param(); | |||
| if (argmaxParam.has_out_max_val()) { | |||
| attr->outMaxValue = argmaxParam.out_max_val(); | |||
| } | |||
| if (argmaxParam.has_top_k()) { | |||
| attr->topK = argmaxParam.top_k(); | |||
| } | |||
| int32_t axisType; | |||
| int32_t axis = 0; | |||
| if (!argmaxParam.has_axis()) { | |||
| @@ -46,15 +55,9 @@ STATUS CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||
| } else { | |||
| axisType = 1; | |||
| axis = (int64_t)argmaxParam.axis(); | |||
| if (axis == -1) { | |||
| MS_LOG(ERROR) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| attr->axis = axis; | |||
| attr->axisType = axisType; | |||
| attr->outMaxValue = argmaxParam.out_max_val(); | |||
| attr->topK = argmaxParam.top_k(); | |||
| attr->keepDims = true; | |||
| op->name = proto.name(); | |||
| @@ -33,18 +33,23 @@ const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"}; | |||
| schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { | |||
| int status = ValidateFileStr(modelFile, ".prototxt"); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| if (weightFile.empty()) { | |||
| MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||
| return nullptr; | |||
| } | |||
| if (ValidateFileStr(weightFile, ".caffemodel") != RET_OK) { | |||
| status = ValidateFileStr(weightFile, ".caffemodel"); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| @@ -52,33 +57,40 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co | |||
| TensorCache tensorCache; | |||
| caffe::NetParameter proto; | |||
| if (ReadProtoFromText((const char *)modelFile.c_str(), &proto) != RET_OK) { | |||
| status = ReadProtoFromText((const char *)modelFile.c_str(), &proto); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| metaGraph->name = proto.name(); | |||
| caffe::NetParameter weight; | |||
| if (ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight) != RET_OK) { | |||
| status = ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto status = GetModelInput(proto, &tensorCache); | |||
| status = GetModelInput(proto, &tensorCache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GetModelInput failed " << status; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| status = ParseLayer(proto, weight, &tensorCache, metaGraph.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseLayer failed " << status; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| status = SetGraphTensorIndex(proto, &tensorCache, metaGraph.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Set inputTensor index and outputTensor index for graph failed!"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| metaGraph->name = GetModelName(modelFile); | |||
| @@ -148,7 +160,12 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, T | |||
| } | |||
| for (auto iter : caffeInspector.GetGraphOutput()) { | |||
| int index = tensorCache->FindTensor(iter); | |||
| int index = -1; | |||
| if (splitLayer.find(iter) != splitLayer.end()) { | |||
| index = tensorCache->FindTensor(splitLayer.find(iter)->second); | |||
| } else { | |||
| index = tensorCache->FindTensor(iter); | |||
| } | |||
| if (index >= 0) { | |||
| subGraphDef->outputIndex.emplace_back(index); | |||
| } else { | |||
| @@ -199,26 +216,28 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff | |||
| op->name = layer.name(); | |||
| if (layer.type() == "Split") { | |||
| splitLayer.emplace(layer.name(), layer.bottom(0)); | |||
| for (int j = 0; j < layer.top_size(); ++j) { | |||
| splitLayer.emplace(layer.top(j), layer.bottom(0)); | |||
| } | |||
| continue; | |||
| } | |||
| auto status = SetOpInputIdx(layer, op.get(), tensorCache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!"; | |||
| return RET_ERROR; | |||
| return status; | |||
| } | |||
| auto nodeParser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type().c_str()); | |||
| if (nodeParser == nullptr) { | |||
| MS_LOG(ERROR) << "Don't support type " << layer.type() << ". for caffe op " << layer.name(); | |||
| return RET_ERROR; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<schema::TensorT *> weightVec; | |||
| status = nodeParser->Parse(layer, layerP, op.get(), &weightVec); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; | |||
| return RET_ERROR; | |||
| return status; | |||
| } | |||
| SetWeightTensor(weightVec, op.get(), tensorCache); | |||
| @@ -226,7 +245,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff | |||
| status = SetOpOutputIdx(layer, op.get(), tensorCache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Set Op " << layer.name() << " Output Index Failed!"; | |||
| return RET_ERROR; | |||
| return status; | |||
| } | |||
| // op->fmkType = FmkType_CAFFE; | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/caffe/caffe_reduce_parser.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, | |||
| const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) { | |||
| MS_LOG(DEBUG) << "parse CaffeReduceParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::ReductionParameter reduce_param = proto.reduction_param(); | |||
| if (reduce_param.has_operation()) { | |||
| switch (reduce_param.operation()) { | |||
| case caffe::ReductionParameter_ReductionOp_MEAN: | |||
| attr->mode = schema::ReduceMode_ReduceMean; | |||
| break; | |||
| case caffe::ReductionParameter_ReductionOp_SUM: | |||
| attr->mode = schema::ReduceMode_ReduceSum; | |||
| break; | |||
| case caffe::ReductionParameter_ReductionOp_SUMSQ: | |||
| attr->mode = schema::ReduceMode_ReduceSumSquare; | |||
| break; | |||
| case caffe::ReductionParameter_ReductionOp_ASUM: | |||
| attr->mode = schema::ReduceMode_ReduceASum; | |||
| default: | |||
| MS_LOG(ERROR) << "reduce parse params fail, unsupported opration: " << reduce_param.operation(); | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| attr->mode = schema::ReduceMode_ReduceSum; | |||
| } | |||
| if (reduce_param.has_axis()) { | |||
| attr->axes = std::vector(1, reduce_param.axis()); | |||
| } else { | |||
| attr->axes = std::vector(1, 0); | |||
| } | |||
| attr->reduceToEnd = true; | |||
| attr->keepDims = false; | |||
| op->name = proto.name(); | |||
| op->primitive->value.type = schema::PrimitiveType_Reduce; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| } | |||
| CaffeNodeRegistrar g_caffeReduceParser("Reduction", new CaffeReduceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_REDUCE_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_REDUCE_PARSER_H | |||
| #include <vector> | |||
| #include "tools/converter/parser/caffe/caffe_node_parser.h" | |||
| #include "tools/converter/parser/caffe/caffe_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Node parser that converts a caffe "Reduction" layer into a mindspore-lite
// Reduce primitive; registered with CaffeNodeParserRegistry under "Reduction".
class CaffeReduceParser : public CaffeNodeParser {
 public:
  CaffeReduceParser() : CaffeNodeParser("reduce") {}

  // Parses `proto` (layer definition) and `weight` (layer weights) into `op`.
  // `weightVec` is the out-parameter for any extracted weight tensors.
  // Returns RET_OK on success or an error status otherwise.
  STATUS Parse(const caffe::LayerParameter &proto,
               const caffe::LayerParameter &weight,
               schema::CNodeT *op,
               std::vector<schema::TensorT *> *weightVec) override;
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_REDUCE_PARSER_H | |||
| @@ -548,8 +548,15 @@ STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| return RET_NULL_PTR; | |||
| } | |||
| MS_LOG(ERROR) << "mslite don't support tanh now"; | |||
| return RET_ERROR; | |||
| std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->type = schema::ActivationType_TANH; | |||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| } | |||
| OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); | |||
| @@ -458,14 +458,18 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) | |||
| schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { | |||
| int status = ValidateFileStr(modelFile, ".onnx"); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| onnx::ModelProto onnx_model; | |||
| if (ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model) != RET_OK) { | |||
| status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| const onnx::GraphProto &onnx_graph = onnx_model.graph(); | |||
| @@ -475,19 +479,25 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con | |||
| // find out input names and const names | |||
| FindGraphInputAndConst(onnx_graph); | |||
| // set const tensor | |||
| if (SetGraphConstTensor(onnx_graph, &tensor_cache)) { | |||
| status = SetGraphConstTensor(onnx_graph, &tensor_cache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetGraphConstTensor failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto dst_graph = std::make_unique<schema::MetaGraphT>(); | |||
| // init onnx model graph input tensor | |||
| if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { | |||
| status = SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetGraphInputTensor failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| // init onnx model graph output tensor | |||
| if (SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { | |||
| status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetGraphOutputTensor failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| // init op node input/output tensor, and dst_op attr | |||
| @@ -499,9 +509,10 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con | |||
| ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache); | |||
| continue; | |||
| } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { | |||
| auto status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); | |||
| status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| continue; | |||
| @@ -509,9 +520,10 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con | |||
| std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>(); | |||
| std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>(); | |||
| auto status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache); | |||
| status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| dst_graph->nodes.emplace_back(std::move(dst_op)); | |||
| @@ -42,11 +42,29 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||
| attr->format = schema::Format_NCHW; | |||
| std::vector<int64_t> shape; | |||
| shape.clear(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "shape") { | |||
| for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { | |||
| shape.push_back(static_cast<int64_t>(onnx_node_attr.ints(i))); | |||
| if (onnx_node.input_size() != 2) { | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "shape") { | |||
| for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { | |||
| shape.push_back(static_cast<int64_t>(onnx_node_attr.ints(i))); | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| onnx::TensorProto input_shape; | |||
| const auto &shape_name = onnx_node.input(1); | |||
| for (const auto &it : onnx_graph.initializer()) { | |||
| if (it.name() == shape_name) { | |||
| input_shape = it; | |||
| break; | |||
| } | |||
| } | |||
| if (input_shape.int64_data_size() == 0) { | |||
| MS_LOG(WARNING) << "shape maybe from another op other than const initializer"; | |||
| } else { | |||
| for (int i = 0; i < input_shape.int64_data_size(); ++i) { | |||
| shape.push_back(input_shape.int64_data(i)); | |||
| } | |||
| } | |||
| } | |||
| @@ -43,26 +43,10 @@ STATUS TfliteL2NormParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tflite_op->inputs.empty()) { | |||
| MS_LOG(ERROR) << "the input is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto data_index = tflite_op->inputs[0]; | |||
| const auto &data_tensor = tflite_tensors[data_index]; | |||
| if (data_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "the input tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto ndim = data_tensor->shape.size(); | |||
| std::vector<int32_t> axis; | |||
| axis.reserve(ndim); | |||
| for (size_t i = 0; i < ndim; i++) { | |||
| axis.emplace_back(i); | |||
| } | |||
| attr->axis = axis; | |||
| attr->epsilon = 0.0f; | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsL2NormOptions(); | |||
| attr->axis = {-1}; | |||
| attr->epsilon = 1e-6f; | |||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||
| op->primitive->value.type = schema::PrimitiveType_L2Norm; | |||
| op->primitive->value.value = attr.release(); | |||
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tflite/tflite_lsh_projection_parser.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteLshProjectionParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||
| schema::CNodeT *op, | |||
| std::vector<int32_t> *tensors_id, | |||
| std::vector<schema::Format> *tensors_format, | |||
| std::map<int, int> *tensors_id_map) { | |||
| MS_LOG(DEBUG) << "parse TfliteLshProjectionParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::LshProjectionT> attr = std::make_unique<schema::LshProjectionT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsLSHProjectionOptions(); | |||
| switch (tflite_attr->type) { | |||
| case tflite::LSHProjectionType_SPARSE: | |||
| attr->type = schema::LshProjectionType_SPARSE; | |||
| break; | |||
| case tflite::LSHProjectionType_DENSE: | |||
| attr->type = schema::LshProjectionType_DENSE; | |||
| break; | |||
| default: | |||
| attr->type = schema::LshProjectionType_UNKNOWN; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_LshProjection; | |||
| op->primitive->value.value = attr.release(); | |||
| for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||
| tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LSH_PROJECTION_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LSH_PROJECTION_PARSER_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Node parser that converts a tflite LSH_PROJECTION operator into a
// mindspore-lite LshProjection primitive; registered with
// TfliteNodeParserRegistry under "LshProjection".
class TfliteLshProjectionParser : public TfliteNodeParser {
 public:
  TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {}

  // Parses `tflite_op` into `op`, mapping the operator's tensors through
  // tensors_id / tensors_format / tensors_id_map converter bookkeeping.
  // Returns RET_OK on success or an error status otherwise.
  STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
               const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
               const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
               schema::CNodeT *op,
               std::vector<int32_t> *tensors_id,
               std::vector<schema::Format> *tensors_format,
               std::map<int, int> *tensors_id_map) override;
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LSH_PROJECTION_PARSER_H | |||
| @@ -56,11 +56,11 @@ STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr< | |||
| if (memcpy_s(tensor->data.data(), tensor->data.size(), tflite_model_buffer[buffer_idx]->data.data(), | |||
| tflite_model_buffer[buffer_idx]->data.size())) { | |||
| MS_LOG(ERROR) << "memcpy tensor data failed"; | |||
| return RET_ERROR; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "src tensor data is empty"; | |||
| return RET_ERROR; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -77,7 +77,8 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor | |||
| } | |||
| // change quant param min to 0 to fit ms-lite ops | |||
| if (tensor->dataType == TypeId::kNumberTypeInt8) { | |||
| if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 | |||
| && tensor->dataType == TypeId::kNumberTypeInt8) { | |||
| quant_param->zeroPoint = quant_param->zeroPoint - 128; | |||
| } | |||
| @@ -114,12 +115,13 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||
| auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | |||
| if (node_parser == nullptr) { | |||
| MS_LOG(ERROR) << "cannot find node parser, opType: " << op_type.c_str(); | |||
| return RET_NULL_PTR; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| if (node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, | |||
| &tensorsFormat, &tensorsIdMap) != RET_OK) { | |||
| int status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, | |||
| &tensorsFormat, &tensorsIdMap); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; | |||
| return RET_ERROR; | |||
| return status; | |||
| } | |||
| sub_graph->nodes.emplace_back(op.release()); | |||
| @@ -158,7 +160,11 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> | |||
| auto &tensor_buffer = tflite_model_buffer.at(tflite_tensor->buffer); | |||
| auto isConst = (!tensor_buffer->data.empty()); | |||
| if (isConst) { | |||
| CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); | |||
| int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "obtain const tensor failed"; | |||
| return status; | |||
| } | |||
| } else if (quantType == QuantType_AwareTraining && tensor->dataType == TypeId::kNumberTypeUInt8) { | |||
| // set in/out tensor to int8 to fit ms-lite op | |||
| tensor->dataType = TypeId::kNumberTypeInt8; | |||
| @@ -204,6 +210,9 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> | |||
| auto iter = tensorsIdMap.find(id); | |||
| if (iter != tensorsIdMap.end()) { | |||
| graph_inputs.push_back(iter->second); | |||
| } else { | |||
| MS_LOG(ERROR) << "get graph input failed"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| } | |||
| sub_graph->inputIndex.assign(graph_inputs.begin(), graph_inputs.end()); | |||
| @@ -220,6 +229,9 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> | |||
| auto iter = tensorsIdMap.find(id); | |||
| if (iter != tensorsIdMap.end()) { | |||
| graph_outputs.push_back(iter->second); | |||
| } else { | |||
| MS_LOG(ERROR) << "get graph output failed"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| } | |||
| sub_graph->outputIndex.assign(graph_outputs.begin(), graph_outputs.end()); | |||
| @@ -306,11 +318,13 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, | |||
| auto tflite_model = ReadTfliteModel(model_file.c_str()); | |||
| if (tflite_model == nullptr) { | |||
| MS_LOG(ERROR) << "read tflite model failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||
| return nullptr; | |||
| } | |||
| if (tflite_model->subgraphs.size() != 1) { | |||
| MS_LOG(ERROR) << "read tflite model subgraphs failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||
| return nullptr; | |||
| } | |||
| const auto &tflite_subgraph = tflite_model->subgraphs[0]; | |||
| @@ -318,31 +332,40 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, | |||
| auto meta_graph = std::make_unique<schema::MetaGraphT>(); | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "new meta graph failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| return nullptr; | |||
| } | |||
| meta_graph->name = "MS_model converted by TF-Lite"; | |||
| quantType = quant_type; | |||
| // convert op | |||
| if (ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get()) != RET_OK) { | |||
| int status = ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "parse op failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| // convert tensor | |||
| if (ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get()) != RET_OK) { | |||
| status = ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "convert tensor failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| // set graph input/output | |||
| if (GetGraphInfo(tflite_subgraph, meta_graph.get()) != RET_OK) { | |||
| status = GetGraphInfo(tflite_subgraph, meta_graph.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "convert tensors failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| // update for depthwiseConv | |||
| if (ConvertGroupDepthwiseOp(meta_graph.get()) != RET_OK) { | |||
| status = ConvertGroupDepthwiseOp(meta_graph.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "convert group depthwise conv failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_RETURN_CODE_H | |||
| #define LITE_RETURN_CODE_H | |||
| #include "include/errorcode.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class ReturnCode { | |||
| public: | |||
| ~ReturnCode() {} | |||
| static ReturnCode *GetSingleReturnCode() { | |||
| static ReturnCode returnCode; | |||
| return &returnCode; | |||
| } | |||
| void UpdateReturnCode(STATUS status) { | |||
| if (statusCode == RET_OK) { | |||
| statusCode = status; | |||
| } | |||
| } | |||
| STATUS GetReturnCode() { | |||
| return statusCode; | |||
| } | |||
| private: | |||
| ReturnCode() { statusCode = RET_OK; } | |||
| int statusCode; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_RETURN_CODE_H | |||
| @@ -79,7 +79,7 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { | |||
| return 0; | |||
| } | |||
| } | |||
| void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) { | |||
| int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) { | |||
| AnfNodePtr conv_bias_node = nullptr; | |||
| AnfNodePtr conv_weight_node = nullptr; | |||
| if (conv_node->inputs().size() == kConvNoBiasLen) { | |||
| @@ -93,11 +93,12 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c | |||
| auto kernel_nums = Get_Kenrnel_nums(conv_node); | |||
| if (kernel_nums <= 0) { | |||
| MS_LOG(EXCEPTION) << "kernel num less than 0"; | |||
| return lite::RET_INVALID_OP_ATTR; | |||
| } | |||
| auto add_bias_data = new (std::nothrow) float[kernel_nums]; | |||
| if (add_bias_data == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_data is nullptr"; | |||
| return; | |||
| return lite::RET_MEMORY_FAILED; | |||
| } | |||
| auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX); | |||
| CheckIfNodeIsParam(bias_add_weight); | |||
| @@ -112,6 +113,7 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c | |||
| } else { | |||
| if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) { | |||
| MS_LOG(EXCEPTION) << "memset_s conv_bias_data failed"; | |||
| return lite::RET_MEMORY_FAILED; | |||
| } | |||
| } | |||
| if (conv_bias_node != nullptr) { | |||
| @@ -120,6 +122,7 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c | |||
| auto conv_bias_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_bias_param); | |||
| if (conv_bias_tensor->tensor_shape().empty() || conv_bias_tensor->tensor_shape()[0] != kernel_nums) { | |||
| MS_LOG(EXCEPTION) << "conv_bias_node shape error"; | |||
| return lite::RET_INVALID_OP_ATTR; | |||
| } | |||
| auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->tensor_addr()); | |||
| for (int i = 0; i < kernel_nums; i++) { | |||
| @@ -133,6 +136,7 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c | |||
| conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias"); | |||
| conv_node->add_input(conv_new_bias); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| } // namespace | |||
| const BaseRef ConvBiasaddFusion::DefinePattern() const { | |||
| @@ -159,7 +163,11 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons | |||
| } | |||
| auto conv_node = conv_node_anf->cast<CNodePtr>(); | |||
| CheckIfCNodeIsNull(conv_node); | |||
| GenConvNewBias(func_graph, conv_node, add_node); | |||
| int ret = GenConvNewBias(func_graph, conv_node, add_node); | |||
| if (ret != lite::RET_OK) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||
| return nullptr; | |||
| } | |||
| auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0)); | |||
| MS_ASSERT(primitive_c != nullptr); | |||
| auto type = primitive_c->Type(); | |||
| @@ -180,6 +188,7 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons | |||
| primc->SetHasBias(true); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType, " << type; | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||
| return nullptr; | |||
| } | |||
| return conv_node; | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "tools/converter/return_code.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||