@@ -89,7 +89,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
 }
 int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
-                                   const std::shared_ptr<PrimitiveC> primitive,
+                                   const std::shared_ptr<PrimitiveC> &primitive,
                                    const std::unique_ptr<schema::CNodeT> &dst_node) {
   MS_ASSERT(meta_graph != nullptr);
   MS_ASSERT(primitive != nullptr);
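Reviewer note: several hunks in this patch (here and in `ConvertInputCNode`, `ConvertInputParameter`, `ConvertInputValueNode`) switch `std::shared_ptr` parameters from pass-by-value to pass-by-`const` reference. A minimal sketch of the difference, using a hypothetical stand-in rather than the real `PrimitiveC`:

```cpp
#include <iostream>
#include <memory>

struct PrimitiveC { int op = 0; };  // stand-in, not the real class

// By value: copying the shared_ptr costs an atomic refcount increment per call.
void ByValue(std::shared_ptr<PrimitiveC> p) { std::cout << p.use_count() << '\n'; }

// By const reference: aliases the caller's pointer, no refcount traffic.
void ByConstRef(const std::shared_ptr<PrimitiveC> &p) { std::cout << p.use_count() << '\n'; }

int main() {
  auto prim = std::make_shared<PrimitiveC>();
  ByValue(prim);     // prints 2: the parameter holds a second reference
  ByConstRef(prim);  // prints 1: only prim owns the object
}
```

The const reference is safe here because the callees only read the primitive and do not need to extend its lifetime.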
@@ -173,7 +173,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
 int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                      schema::CNodeT *return_node) {
-  MS_ASSERT(nullptr != meta_graph);
+  MS_ASSERT(nullptr != meta_graphT);
   MS_ASSERT(nullptr != return_node);
   for (size_t i = 1; i < cnode->inputs().size(); i++) {
     auto input_node = cnode->input(i);
@@ -191,8 +191,8 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_pt
       return RET_ERROR;
     }
   }
-  for (size_t i = 0; i < return_node->inputIndex.size(); ++i) {
-    meta_graphT->outputIndex.push_back(return_node->inputIndex[i]);
+  for (unsigned int &i : return_node->inputIndex) {
+    meta_graphT->outputIndex.push_back(i);
   }
   return RET_OK;
 }
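Reviewer note: the loop rewrite above replaces index arithmetic with a range-for. Assuming `inputIndex` is a `std::vector<uint32_t>` (as elsewhere in the schema), a sketch of the equivalence; since the elements are only read, iterating by value would state the intent even more clearly than the mutable `unsigned int &i`:

```cpp
#include <cstdint>
#include <vector>

int main() {
  std::vector<uint32_t> inputIndex{3, 1, 4};
  std::vector<uint32_t> outputIndex;
  // Same effect as the indexed loop; by-value iteration is enough for reads.
  for (uint32_t idx : inputIndex) {
    outputIndex.push_back(idx);
  }
  return outputIndex.size() == 3 ? 0 : 1;
}
```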
@@ -272,7 +272,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
   return meta_graphT.release();
 }
-int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode) {
+int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) {
   std::string input_name = input_anode->fullname_with_scope();
   auto input_cnode = utils::cast<CNodePtr>(input_anode);
@@ -336,7 +336,7 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s
   return RET_OK;
 }
-int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anode,
+int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode,
                                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                        schema::CNodeT *output_cnode) {
   auto paramNode = input_anode->cast<ParameterPtr>();
@@ -382,7 +382,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
   return RET_OK;
 }
-int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
+int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
                                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                        schema::CNodeT *output_cnode) {
   auto valueNode = input_anode->cast<ValueNodePtr>();
@@ -478,7 +478,7 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
 int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                 schema::CNodeT *fb_node) {
-  MS_ASSERT(nullptr != meta_graph);
+  MS_ASSERT(nullptr != meta_graphT);
   MS_ASSERT(nullptr != fb_node);
   if (cnode->inputs().size() <= 1) {
     return RET_OK;
@@ -518,14 +518,14 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
 void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                   schema::CNodeT *fb_node) {
-  MS_ASSERT(nullptr != graph);
+  MS_ASSERT(nullptr != meta_graphT);
   MS_ASSERT(nullptr != fb_node);
   std::string cnode_name = fb_node->name;
   if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
     auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
     for (size_t i = 0; i < tuple->size(); i++) {
-      auto msTensor = new schema::TensorT();
+      auto msTensor = new (std::nothrow) schema::TensorT();
       msTensor->nodeType = schema::NodeType_CNode;
       fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
 #ifdef SUPPORT_TRAIN
@@ -552,7 +552,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
 #endif
     }
   } else {
-    auto ms_tensor = new schema::TensorT();
+    auto ms_tensor = new (std::nothrow) schema::TensorT();
     ms_tensor->nodeType = schema::NodeType_CNode;
     ms_tensor->dataType = TypeId::kNumberTypeFloat32;
     fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
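Reviewer note: `new (std::nothrow)` changes the failure mode from throwing `std::bad_alloc` to returning `nullptr`. In the hunks above `msTensor`/`ms_tensor` are dereferenced on the very next line, so a null check would presumably be needed to get the full benefit. A minimal sketch of the intended pattern (hypothetical simplified `TensorT`):

```cpp
#include <iostream>
#include <new>

struct TensorT { int nodeType = 0; };  // simplified stand-in

int main() {
  // nothrow new reports allocation failure via nullptr, not an exception,
  // so the result must be checked before the first dereference.
  auto *tensor = new (std::nothrow) TensorT();
  if (tensor == nullptr) {
    std::cerr << "new TensorT failed\n";
    return 1;
  }
  tensor->nodeType = 1;
  delete tensor;
  return 0;
}
```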
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_
-#define MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_
+#ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_
+#define MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_
 #include <map>
 #include <string>
@@ -36,21 +36,22 @@ class AnfExporter {
                       schema::CNodeT *fb_node);
   int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                      schema::CNodeT *fb_node);
-  void RemoveIfMakeTuple(const CNodePtr &cnode);
-  void RemoveIfDepend(const CNodePtr &cnode);
+  static void RemoveIfMakeTuple(const CNodePtr &cnode);
+  static void RemoveIfDepend(const CNodePtr &cnode);
  protected:
-  int ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode);
-  int ConvertInputParameter(const std::shared_ptr<AnfNode> input_anode,
+  int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode);
+  int ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode,
                             const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
-  int ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
+  int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
                             const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
   void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
   int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                           schema::CNodeT *return_node);
-  bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
-  int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
-                        const std::shared_ptr<PrimitiveC> primitive, const std::unique_ptr<schema::CNodeT> &dst_node);
+  static bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
+  static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
+                               const std::shared_ptr<PrimitiveC> &primitive,
+                               const std::unique_ptr<schema::CNodeT> &dst_node);
  private:
   std::map<std::string, int> node_id_map_;
@@ -62,4 +63,4 @@ class AnfExporter {
 // and clear.
 schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
 }  // namespace mindspore::lite
-#endif  // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_
+#endif  // MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_
@@ -15,8 +15,6 @@
  */
 #include <utility>
-#include <memory>
-#include <vector>
 #include "tools/anf_importer/anf_importer.h"
 #include "schema/model_generated.h"
 #include "ir/dtype.h"
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_
-#define MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_
+#ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_
+#define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_
 #include <unordered_map>
 #include "ir/func_graph.h"
@@ -51,4 +51,4 @@ class AnfImporter {
 };
 }  // namespace mindspore::lite
-#endif  // MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_
+#endif  // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_
@@ -22,7 +22,6 @@
 #include "src/param_value_lite.h"
 #include "src/common/log_adapter.h"
 #include "include/errorcode.h"
-#include "tools/common/tensor_util.h"
 namespace mindspore::lite {
 int AnfImporterFromMetaGraphT::ConverterConstTensor() {
@@ -31,11 +30,9 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
   for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) {
     auto &tensor = meta_graph_->allTensors.at(i);
     MS_ASSERT(tensor != nullptr);
-    // converter weight and graph input into parameter node
     if (tensor->nodeType != schema::NodeType::NodeType_ValueNode) {
       continue;
     }
-    MS_ASSERT(tensor->dims() != nullptr);
     auto parameter = func_graph_->add_parameter();
     std::vector<int> shape(tensor->dims.size());
     std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin());
@@ -45,11 +42,12 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
     (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
                          [](const int32_t &value) { return static_cast<int64_t>(value); });
     auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
+    MS_ASSERT(nullptr != abstract_tensor);
     parameter->set_abstract(abstract_tensor);
     parameter->set_name("const_" + std::to_string(i) + "_parameter");
     ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
-    MS_ASSERT(param_value != nullptr);
+    MS_ASSERT(nullptr != param_value);
     param_value->set_tensor_shape(shape);
     param_value->set_tensor_type(type_id);
     param_value->set_format(tensor->format);
@@ -123,7 +121,9 @@ abstract::AbstractTensorPtr AnfImporterFromMetaGraphT::ConvertTensorToAbstractTe
   std::vector<int64_t> shape_vector;
   (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
                        [](const int32_t &value) { return static_cast<int64_t>(value); });
-  return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
+  auto ptr = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
+  MS_ASSERT(nullptr != ptr);
+  return ptr;
 }
 int AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr<schema::CNodeT> &src_cnode,
@@ -175,15 +175,16 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
     return RET_NULL_PTR;
   }
   std::vector<AnfNodePtr> op_inputs = {anf_primitive};
-  for (unsigned int j : cNode->inputIndex) {
+  for (int j : cNode->inputIndex) {
     auto node = GetNode(j);
     if (nullptr == node) {
       MS_LOG(ERROR) << "Can't find input node.";
-      return RET_ERROR;
+      return RET_NULL_PTR;
     }
     op_inputs.push_back(node);
   }
   auto new_cnode = func_graph_->NewCNode(op_inputs);
+  MS_ASSERT(nullptr != new_cnode);
   new_cnode->set_fullname_with_scope(cNode->name);
   auto status = ConvertAbstract(cNode, new_cnode);
   if (status != RET_OK) {
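Reviewer note: one detail worth flagging in the hunk above is that the loop variable changes from `unsigned int` to `int`, while `cNode->inputIndex` presumably holds `uint32_t` indices (as in the exporter hunks). Iterating a vector of unsigned indices with a signed type converts each element; a sketch of the edge case:

```cpp
#include <cstdint>
#include <vector>

int main() {
  std::vector<uint32_t> inputIndex{0x80000000u};  // index above INT_MAX
  // With int, the element is converted and can come out negative,
  // which is surprising for a tensor index; uint32_t avoids this.
  for (int j : inputIndex) {
    if (j < 0) return 1;
  }
  return 0;
}
```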
@@ -195,10 +196,8 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
 }
 int AnfImporterFromMetaGraphT::AddReturnCNode() {
-  if (meta_graph_ == nullptr || func_graph_ == nullptr) {
-    MS_LOG(ERROR) << "meta_graph or func_graph is nullptr";
-    return RET_NULL_PTR;
-  }
+  MS_ASSERT(nullptr != meta_graph_);
+  MS_ASSERT(nullptr != func_graph_);
   if (meta_graph_->outputIndex.size() > 1) {
     std::vector<AnfNodePtr> make_tuple_inputs;
     auto make_tuple_prim_ptr = GetMakeTuplePrim();
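Reviewer note: the hunk above trades a runtime null check (log + `RET_NULL_PTR`) for `MS_ASSERT`, which typically compiles away in release builds. The two guards fail differently; a small sketch of the trade-off, using the standard `assert` in place of `MS_ASSERT`:

```cpp
#include <cassert>
#include <iostream>

// Debug builds: assert aborts loudly at the fault site.
// Release builds (NDEBUG): assert vanishes, only the branch still guards.
int UseGraph(const int *meta_graph) {
  assert(meta_graph != nullptr);   // the style the patch moves to
  if (meta_graph == nullptr) {     // the style the patch removes
    std::cerr << "meta_graph is nullptr\n";
    return -1;
  }
  return *meta_graph;
}

int main() {
  int value = 7;
  return UseGraph(&value) == 7 ? 0 : 1;
}
```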
@@ -229,6 +228,7 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() {
     op_inputs.emplace_back(value_node);
     op_inputs.emplace_back(make_tuple_cnode);
     auto cnode = func_graph_->NewCNode(op_inputs);
+    MS_ASSERT(nullptr != cnode);
     cnode->set_fullname_with_scope("return");
     func_graph_->set_return(cnode);
   } else {
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_
-#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_
+#ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_
+#define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_
 #include <utility>
 #include <memory>
@@ -40,7 +40,9 @@ class AnfImporterFromMetaGraphT : public AnfImporter {
   int ConverterCNode() override;
   ValueNodePtr ConvertPrimitive(const std::unique_ptr<schema::CNodeT> &cNode);
-  abstract::AbstractTensorPtr ConvertTensorToAbstractTensor(const std::unique_ptr<schema::TensorT> &tensor);
+  static abstract::AbstractTensorPtr ConvertTensorToAbstractTensor(const std::unique_ptr<schema::TensorT> &tensor);
   int ConvertAbstract(const std::unique_ptr<schema::CNodeT> &src_cnode, const CNodePtr &dst_cnode);
   int AddReturnCNode() override;
@@ -51,4 +53,4 @@ class AnfImporterFromMetaGraphT : public AnfImporter {
 };
 }  // namespace mindspore::lite
-#endif  // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_
+#endif  // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_
@@ -239,7 +239,7 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
   node->set_abstract(abstract_tensor);
   if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
-    auto *tensor_info = new Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
+    auto *tensor_info = new (std::nothrow) Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
     if (tensor_info == nullptr) {
       return RET_MEMORY_FAILED;
     }
@@ -345,7 +345,6 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor
       MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
       return {};
   }
-  return {};
 }
 bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
@@ -871,7 +870,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
 }
 onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
-  auto onnx_model = new onnx::ModelProto;
+  auto onnx_model = new (std::nothrow) onnx::ModelProto;
   if (RET_OK != ValidateFileStr(model_path, ".mindir")) {
     MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID);
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
-#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
+#ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
+#define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
 #include <map>
 #include <string>
@@ -81,4 +81,4 @@ class AnfImporterFromProtobuf : public AnfImporter {
 };
 }  // namespace mindspore::lite
-#endif  // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
+#endif  // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
@@ -24,7 +24,6 @@ Option<std::string> FlagParser::ParseFlags(int argc, const char *const *argv, bo
                                            bool supportDuplicate) {
   MS_ASSERT(argv != nullptr);
   const int FLAG_PREFIX_LEN = 2;
-  // Get binary name
   binName = GetFileName(argv[0]);
   std::multimap<std::string, Option<std::string>> keyValues;
@@ -45,9 +44,7 @@ Option<std::string> FlagParser::ParseFlags(int argc, const char *const *argv, bo
     Option<std::string> value = Option<std::string>(None());
     size_t pos = flagItem.find_first_of('=');
-    if (pos == std::string::npos && flagItem.find("--no-") != std::string::npos) {
-      key = flagItem.substr(FLAG_PREFIX_LEN);
-    } else if (pos == std::string::npos) {
+    if (pos == std::string::npos) {
       key = flagItem.substr(FLAG_PREFIX_LEN);
     } else {
       key = flagItem.substr(FLAG_PREFIX_LEN, pos - FLAG_PREFIX_LEN);
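Reviewer note: the removed `--no-` branch computed exactly the same key as the plain `pos == npos` branch, so dropping it does not change parsing behavior. A sketch of the equivalence:

```cpp
#include <iostream>
#include <string>

int main() {
  const int FLAG_PREFIX_LEN = 2;  // length of the "--" prefix
  for (const std::string flagItem : {"--no-verbose", "--level=3"}) {
    size_t pos = flagItem.find_first_of('=');
    // Identical to the old two-branch version: without '=', the key is
    // everything after "--" (including a leading "no-").
    std::string key = (pos == std::string::npos)
                          ? flagItem.substr(FLAG_PREFIX_LEN)
                          : flagItem.substr(FLAG_PREFIX_LEN, pos - FLAG_PREFIX_LEN);
    std::cout << key << '\n';  // "no-verbose", then "level"
  }
}
```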
@@ -81,10 +78,10 @@ bool FlagParser::GetRealFlagName(std::string *flagName, const std::string &oriFl
 // Inner parse function
 Option<std::string> FlagParser::InnerParseFlags(std::multimap<std::string, Option<std::string>> *keyValues) {
   MS_ASSERT(keyValues != nullptr);
-  for (auto it = keyValues->begin(); it != keyValues->end(); ++it) {
+  for (auto &keyValue : *keyValues) {
     std::string flagName;
-    bool opaque = GetRealFlagName(&flagName, (*it).first);
-    Option<std::string> flagValue = (*it).second;
+    bool opaque = GetRealFlagName(&flagName, keyValue.first);
+    Option<std::string> flagValue = keyValue.second;
     auto item = flags.find(flagName);
     if (item == flags.end()) {
@@ -133,7 +130,7 @@ Option<std::string> FlagParser::InnerParseFlags(std::multimap<std::string, Optio
   return Option<std::string>(None());
 }
-void Replaceall(std::string *str, const std::string &oldValue, const std::string &newValue) {
+void ReplaceAll(std::string *str, const std::string &oldValue, const std::string &newValue) {
   if (str == nullptr) {
     MS_LOG(ERROR) << "Input str is nullptr";
     return;
@@ -153,9 +150,9 @@ std::string FlagParser::Usage(const Option<std::string> &usgMsg) const {
   std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : "";
   // usage of bin name
   usageString += usageMsg.IsNone() ? "\nusage: " + binName + " [options]\n" : usageMsg.Get() + "\n";
-  // help line of help message, usageLine:message of parametors
-  std::string helpLine = "";
-  std::string usageLine = "";
+  // help line of help message, usageLine:message of parameters
+  std::string helpLine;
+  std::string usageLine;
   uint32_t i = 0;
   for (auto flag = flags.begin(); flag != flags.end(); flag++) {
     std::string flagName = flag->second.flagName;
@@ -165,7 +162,7 @@ std::string FlagParser::Usage(const Option<std::string> &usgMsg) const {
     if (++i <= flags.size()) {
       // add parameter help message of each line
       thisLine += " " + helpInfo;
-      Replaceall(&helpInfo, "\n\r", "\n");
+      ReplaceAll(&helpInfo, "\n\r", "\n");
       usageLine += thisLine + "\n";
     } else {
       // breif help message
@@ -14,21 +14,18 @@
  * limitations under the License.
  */
-#ifndef PREDICT_COMMON_FLAG_PARSER_H_
-#define PREDICT_COMMON_FLAG_PARSER_H_
+#ifndef MINDSPORE_LITE_TOOLS_COMMON_FLAG_PARSER_H
+#define MINDSPORE_LITE_TOOLS_COMMON_FLAG_PARSER_H
 #include <functional>
 #include <map>
 #include <utility>
 #include <string>
 #include "src/common/utils.h"
 #include "tools/common/option.h"
 namespace mindspore {
 namespace lite {
-struct FlagInfo;
 struct Nothing {};
 class FlagParser {
@@ -44,6 +41,7 @@ class FlagParser {
   template <typename Flags, typename T1, typename T2>
   void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2);
   template <typename Flags, typename T1, typename T2>
   void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2);
@@ -94,7 +92,7 @@ class FlagParser {
   Option<std::string> InnerParseFlags(std::multimap<std::string, Option<std::string>> *values);
-  bool GetRealFlagName(std::string *flagName, const std::string &oriFlagName);
+  static bool GetRealFlagName(std::string *flagName, const std::string &oriFlagName);
   std::map<std::string, FlagInfo> flags;
 };
@@ -181,7 +179,7 @@ void FlagParser::AddFlag(T1 *t1, const std::string &flagName, const std::string
   FlagInfo flagItem;
-  // flagItem is as a output parameter
+  // flagItem is as an output parameter
   ConstructFlag(t1, flagName, helpInfo, flagItem);
   flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option<Nothing> {
     if (base != nullptr) {
@@ -301,4 +299,4 @@ void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const
 }  // namespace lite
 }  // namespace mindspore
-#endif  // PREDICT_COMMON_FLAG_PARSER_H_
+#endif  // MINDSPORE_LITE_TOOLS_COMMON_FLAG_PARSER_H
@@ -15,8 +15,7 @@
  */
 #include "tools/common/graph_util.h"
-#include <stdlib.h>
-#include <time.h>
+#include <ctime>
 #include <utility>
 #include <set>
 #include "schema/inner/model_generated.h"
@@ -29,7 +28,10 @@ namespace mindspore {
 namespace lite {
 OpDefCopyer GetSimpleOpCopyer() {
   return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
-    std::unique_ptr<CNodeT> newCNode(new CNodeT);
+    std::unique_ptr<CNodeT> newCNode = std::make_unique<CNodeT>();
+    if (newCNode == nullptr) {
+      return nullptr;
+    }
    newCNode->name = inCNode->name;
     newCNode->quantType = inCNode->quantType;
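Reviewer note: on the hunk above, `std::make_unique` allocates with the throwing `operator new`, so on a standard toolchain it never returns `nullptr`; the added check only fires in configurations where allocation failure does not throw (for example, exception-free embedded builds). A sketch:

```cpp
#include <memory>

struct CNodeT { int quantType = 0; };  // simplified stand-in

std::unique_ptr<CNodeT> CopyNode(const CNodeT &in) {
  // make_unique value-initializes and raises std::bad_alloc on failure;
  // the nullptr check below is dead code unless exceptions are disabled.
  auto out = std::make_unique<CNodeT>();
  if (out == nullptr) {
    return nullptr;
  }
  out->quantType = in.quantType;
  return out;
}

int main() {
  CNodeT src{3};
  auto dst = CopyNode(src);
  return (dst != nullptr && dst->quantType == 3) ? 0 : 1;
}
```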
@@ -163,8 +165,6 @@ STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
     }
   }
-  // whether need to remove weightInputTensores
-  // remove all node's outputTensors
   RemoveTensor(graphT, outputTensorIdxes);
   node->inputIndex.clear();
   node->outputIndex.clear();
@@ -183,8 +183,11 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool remove
     MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
     return RET_PARAM_INVALID;
   }
   CNodeT *node = graphT->nodes.at(nodeIdx).get();
+  if (node == nullptr) {
+    MS_LOG(ERROR) << "node is null";
+    return RET_NULL_PTR;
+  }
   auto inputTensorIdxes = node->inputIndex;
   auto outputTensorIdxes = node->outputIndex;
   auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
@@ -244,6 +247,7 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTe
   size_t nodeIdx = 0;
   for (size_t i = 0; i < graphT->nodes.size(); i++) {
     auto &inNode = graphT->nodes.at(i);
+    MS_ASSERT(inNode != nullptr);
     if (inNode->name == node->name) {
       isSubNode = true;
       nodeIdx = i;
@@ -259,6 +263,7 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTe
 }
 STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
+  MS_ASSERT(graphT != nullptr);
   for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
     uint32_t deleteIdx = *iter;
     if (!forceDelete) {
@@ -297,6 +302,7 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe
 }
 STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
+  MS_ASSERT(node != nullptr);
   for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
     if (*inIdxIt == deleteIdx) {
       inIdxIt = node->inputIndex.erase(inIdxIt);
@@ -330,6 +336,7 @@ STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_
   graphT->allTensors.emplace_back(std::move(tensor));
   uint32_t newTensorIdx = graphT->allTensors.size() - 1;
   auto node = graphT->nodes.at(nodeIdx).get();
+  MS_ASSERT(node != nullptr);
   if (place == kBefore) {
     node->inputIndex.emplace_back(newTensorIdx);
   } else {
@@ -340,11 +347,13 @@ STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_
 STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx,
                            std::unique_ptr<TensorT> tensor) {
+  MS_ASSERT(graphT != nullptr);
   if (nodeIdx >= graphT->nodes.size()) {
     MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
     return RET_PARAM_INVALID;
   }
   auto node = graphT->nodes.at(nodeIdx).get();
+  MS_ASSERT(node != nullptr);
   if (inTensorIdx >= graphT->allTensors.size()) {
     MS_LOG(ERROR) << "inTensorIdx out of range: " << nodeIdx;
     return RET_PARAM_INVALID;
@@ -358,7 +367,9 @@ STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_
 }
 NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex,
-                    std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) {
+                    std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer) {
+  MS_ASSERT(graphT != nullptr);
+  MS_ASSERT(errorCode != nullptr);
   if (existNodeIdx >= graphT->nodes.size()) {
     MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
     return graphT->nodes.end();
@@ -370,7 +381,9 @@ NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPla
 }
 NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
-                    std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) {
+                    std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer) {
+  MS_ASSERT(graphT != nullptr);
+  MS_ASSERT(errorCode != nullptr);
   if (place == kBefore) {
     return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer);
   } else if (place == kAfter) {
@@ -382,7 +395,9 @@ NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPl
 }
 NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
-                          std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) {
+                          std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, const OpDefCopyer &opDefCopyer) {
+  MS_ASSERT(graphT != nullptr);
+  MS_ASSERT(errorCode != nullptr);
   auto &existNode = *existNodeIter;
   MS_ASSERT(existNode != nullptr);
   MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx);
@@ -390,7 +405,7 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
   auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx);
   MS_ASSERT(graphT->allTensors.size() > preTensorIdx);
-  auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode.get()), inputIndexIdx);
+  auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode), inputIndexIdx);
   if (preNodeIdxes.empty()) {
     auto &preTensor = graphT->allTensors.at(preTensorIdx);
     MS_ASSERT(preTensor != nullptr);
@@ -402,9 +417,12 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
     }
     preTensor->refCount = 0;
     preTensor->data.clear();
+    MS_ASSERT(toAddNodeIn->primitive != nullptr);
     if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
-      preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
-      toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
+      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
+      MS_ASSERT(prim != nullptr);
+      preTensor->dataType = prim->srcT;
+      toAddTensor->dataType = prim->dstT;
     }
     graphT->allTensors.emplace_back(std::move(toAddTensor));
     size_t toAddTensorIdx = graphT->allTensors.size() - 1;
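Reviewer note: the repeated pattern in these hunks caches the result of `AsQuantDTypeCast()` once and asserts it before use, instead of calling the accessor twice and dereferencing blindly; if the union actually held a different primitive type, the old code would dereference null. A sketch with a hypothetical, much-simplified union accessor:

```cpp
#include <cassert>

struct QuantDTypeCastT { int srcT; int dstT; };  // stand-in for the schema type

// Hypothetical accessor: returns nullptr when the union holds another type.
struct PrimitiveValue {
  QuantDTypeCastT *quant = nullptr;
  QuantDTypeCastT *AsQuantDTypeCast() const { return quant; }
};

int main() {
  QuantDTypeCastT q{1, 2};
  PrimitiveValue value{&q};
  // One lookup, one explicit validity check, then plain field reads.
  auto *prim = value.AsQuantDTypeCast();
  assert(prim != nullptr);
  return (prim->srcT == 1 && prim->dstT == 2) ? 0 : 1;
}
```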
@@ -438,9 +456,12 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
       MS_LOG(ERROR) << "Copy TensorT failed";
       return graphT->nodes.end();
     }
+    MS_ASSERT(toAddNodeIn->primitive != nullptr);
     if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
-      preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
-      toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
+      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
+      MS_ASSERT(prim != nullptr);
+      preTensor->dataType = prim->srcT;
+      toAddTensor->dataType = prim->dstT;
     }
     graphT->allTensors.emplace_back(std::move(toAddTensor));
     size_t toAddTensorIdx = graphT->allTensors.size() - 1;
@@ -473,7 +494,10 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
 }
 NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
-                         std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) {
+                         std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode,
+                         const OpDefCopyer &opDefCopyer) {
+  MS_ASSERT(graphT != nullptr);
+  MS_ASSERT(errorCode != nullptr);
   auto &existNode = *existNodeIter;
   MS_ASSERT(existNode != nullptr);
   MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx);
@@ -481,7 +505,7 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
   auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx);
   MS_ASSERT(graphT->allTensors.size() > postTensorIdx);
-  auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode.get()), outputIndexIdx);
+  auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode), outputIndexIdx);
   if (postNodeIdxes.empty()) {
     auto &postTensor = graphT->allTensors.at(postTensorIdx);
     MS_ASSERT(postTensor != nullptr);
@@ -491,9 +515,12 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
       *errorCode = RET_NULL_PTR;
       return graphT->nodes.end();
     }
+    MS_ASSERT(toAddNodeIn->primitive != nullptr);
     if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
-      postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
-      toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
+      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
+      MS_ASSERT(prim != nullptr);
+      postTensor->dataType = prim->srcT;
+      toAddTensor->dataType = prim->dstT;
     }
     graphT->allTensors.emplace_back(std::move(toAddTensor));
     size_t toAddTensorIdx = graphT->allTensors.size() - 1;
@@ -554,9 +581,12 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
       *errorCode = RET_NULL_PTR;
       return graphT->nodes.end();
     }
+    MS_ASSERT(toAddNodeIn->primitive != nullptr);
     if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
-      postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
-      toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
+      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
+      MS_ASSERT(prim != nullptr);
+      postTensor->dataType = prim->srcT;
+      toAddTensor->dataType = prim->dstT;
     }
     graphT->allTensors.emplace_back(std::move(toAddTensor));
     size_t toAddTensorIdx = graphT->allTensors.size() - 1;
@@ -589,13 +619,9 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
   return existNodeIter;
 }
-STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) {
-  if (modelFile.size() > fileType.size()) {
-    if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
-      return RET_OK;
-    } else {
-      return RET_ERROR;
-    }
+STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType) {
+  if (modelFile.size() > fileType.size() && modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
+    return RET_OK;
   } else {
     return RET_ERROR;
   }
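Reviewer note: the refactored `ValidateFileStr` collapses the nested branches into a single suffix test and takes `fileType` by const reference, avoiding a string copy per call. The same predicate extracted standalone (a sketch, not the library function):

```cpp
#include <iostream>
#include <string>

// True when modelFile is strictly longer than fileType and ends with it;
// behaviorally identical to the refactored ValidateFileStr returning RET_OK.
bool HasSuffix(const std::string &modelFile, const std::string &fileType) {
  return modelFile.size() > fileType.size() &&
         modelFile.substr(modelFile.size() - fileType.size()) == fileType;
}

int main() {
  std::cout << HasSuffix("model.mindir", ".mindir") << '\n';  // 1
  std::cout << HasSuffix("model.onnx", ".mindir") << '\n';    // 0
  std::cout << HasSuffix(".mindir", ".mindir") << '\n';       // 0 (not strictly longer)
}
```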
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_PREDICT_GRAPH_UTIL_H | |||||
| #define MINDSPORE_PREDICT_GRAPH_UTIL_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H | |||||
| #define MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| @@ -23,7 +23,6 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "src/common/graph_util.h" | #include "src/common/graph_util.h" | ||||
| @@ -73,19 +72,19 @@ STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_ | |||||
| NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex, | NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex, | ||||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, | std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, | ||||
| OpDefCopyer opDefCopyer = GetSimpleOpCopyer()); | |||||
| const OpDefCopyer &opDefCopyer = GetSimpleOpCopyer()); | |||||
| NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx, | NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx, | ||||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, | std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, | ||||
| OpDefCopyer opDefCopyer = GetSimpleOpCopyer()); | |||||
| const OpDefCopyer &opDefCopyer = GetSimpleOpCopyer()); | |||||
| NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx, | NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx, | ||||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer); | |||||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer); | |||||
| NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx, | NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx, | ||||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer); | |||||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer); | |||||
| STATUS ValidateFileStr(const std::string &modelFile, std::string fileType); | |||||
| STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType); | |||||
| void TransformAttrByAxes(int *origin_attr, int *axes, int element_size); | void TransformAttrByAxes(int *origin_attr, int *axes, int element_size); | ||||
| @@ -97,4 +96,4 @@ std::string GetModelName(const std::string &modelFile); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_PREDICT_GRAPH_UTIL_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H | |||||
| @@ -160,6 +160,7 @@ std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; } | |||||
| STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims, | STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims, | ||||
| mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) { | mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) { | ||||
| MS_ASSERT(nullptr != dst_dims); | |||||
| if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) { | if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) { | ||||
| MS_LOG(ERROR) << "Convert format , src size " << src_dims.size() | MS_LOG(ERROR) << "Convert format , src size " << src_dims.size() | ||||
| << " <3 or src format is equal to dst format,not need convert"; | << " <3 or src format is equal to dst format,not need convert"; | ||||
| @@ -189,7 +190,7 @@ STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::v | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (nchw_dim.size() == 0) { | |||||
| if (nchw_dim.empty()) { | |||||
| MS_LOG(ERROR) << "Param nchw_dim is empty!"; | MS_LOG(ERROR) << "Param nchw_dim is empty!"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -215,6 +216,10 @@ STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::v | |||||
| STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, | STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, | ||||
| int32_t *filterH, int32_t *filterW) { | int32_t *filterH, int32_t *filterW) { | ||||
| if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) { | |||||
| MS_LOG(ERROR) << "null input"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_ASSERT(oriDims.size() == 4); | MS_ASSERT(oriDims.size() == 4); | ||||
| if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) { | if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) { | ||||
| *filterK = oriDims.at(KCHW_K); | *filterK = oriDims.at(KCHW_K); | ||||
| @@ -282,6 +287,7 @@ STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filt | |||||
| STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { | STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { | ||||
| if (tensor == nullptr) { | if (tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<int32_t> oriDims = tensor->dims; | std::vector<int32_t> oriDims = tensor->dims; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_PREDICT_NODE_UTIL_H | |||||
| #define MINDSPORE_PREDICT_NODE_UTIL_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H | |||||
| #define MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -60,13 +60,6 @@ class NodeUtils { | |||||
| public: | public: | ||||
| static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format, | static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format, | ||||
| std::vector<int32_t> *dst_dims); | std::vector<int32_t> *dst_dims); | ||||
| static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin, | |||||
| int64_t out_dim, int64_t stride); | |||||
| static STATUS SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int32_t> &input_dims, | |||||
| std::vector<int32_t> &begin, std::vector<int32_t> &output_dims, | |||||
| schema::TensorT *output, std::vector<int32_t> &stride); | |||||
| }; | }; | ||||
| enum kTransFilterType { | enum kTransFilterType { | ||||
| @@ -133,7 +126,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in | |||||
| if (type == kCHWK2HWCK) { | if (type == kCHWK2HWCK) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | ||||
| } else if (type == kCHWK2KHWC) { | |||||
| } else { | |||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | ||||
| } | } | ||||
| @@ -334,4 +327,4 @@ static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) | |||||
| STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat); | STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_PREDICT_NODE_UTIL_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_COMMON_OPTION_H_ | |||||
| #define PREDICT_COMMON_OPTION_H_ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_OPTION_H | |||||
| #define MINDSPORE_LITE_TOOLS_COMMON_OPTION_H | |||||
| #include <type_traits> | #include <type_traits> | ||||
| #include <utility> | #include <utility> | ||||
| @@ -56,7 +56,7 @@ class Option { | |||||
| } | } | ||||
| } | } | ||||
| virtual ~Option() {} | |||||
| virtual ~Option() = default; | |||||
| bool IsNone() const { return state == NONE; } | bool IsNone() const { return state == NONE; } | ||||
| @@ -116,4 +116,4 @@ class Option { | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_COMMON_OPTION_H_ | |||||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_OPTION_H | |||||
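For Option's virtual destructor, `= default` is behaviorally identical to the empty body (a virtual destructor is never trivial); the broader payoff of the idiom is that for non-virtual destructors a defaulted one preserves trivial destructibility while a user-written empty body does not. A self-contained illustration:

    #include <type_traits>

    struct WithBody {
      ~WithBody() {}            // user-provided: the type loses triviality
    };
    struct Defaulted {
      ~Defaulted() = default;   // compiler-generated: the type stays trivial
    };

    static_assert(!std::is_trivially_destructible<WithBody>::value, "");
    static_assert(std::is_trivially_destructible<Defaulted>::value, "");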
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_PROTOBUF_UTILS_H | |||||
| #define MINDSPORE_LITE_TOOLS_COMMON_PROTOBUF_UTILS_H | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -35,4 +35,4 @@ STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *mess | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ | |||||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_PROTOBUF_UTILS_H | |||||
| @@ -50,7 +50,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath | |||||
| } | } | ||||
| schema::MetaGraphT *Storage::Load(const std::string &inputPath) { | schema::MetaGraphT *Storage::Load(const std::string &inputPath) { | ||||
| size_t size; | |||||
| size_t size = 0; | |||||
| auto buf = ReadFile(inputPath.c_str(), &size); | auto buf = ReadFile(inputPath.c_str(), &size); | ||||
| if (buf == nullptr) { | if (buf == nullptr) { | ||||
| MS_LOG(ERROR) << "the file buffer is nullptr"; | MS_LOG(ERROR) << "the file buffer is nullptr"; | ||||
| @@ -58,7 +58,7 @@ schema::MetaGraphT *Storage::Load(const std::string &inputPath) { | |||||
| } | } | ||||
| flatbuffers::Verifier verify((const uint8_t *)buf, size); | flatbuffers::Verifier verify((const uint8_t *)buf, size); | ||||
| if (false == schema::VerifyMetaGraphBuffer(verify)) { | |||||
| if (!schema::VerifyMetaGraphBuffer(verify)) { | |||||
| MS_LOG(ERROR) << "the buffer is invalid and fail to create meta graph"; | MS_LOG(ERROR) << "the buffer is invalid and fail to create meta graph"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
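Two small hardening points in the Storage::Load hunk: `size` is now zero-initialized because ReadFile can fail without ever writing the out-parameter, and `!schema::VerifyMetaGraphBuffer(...)` replaces the Yoda-style `false ==` comparison. A minimal sketch of the out-parameter pattern, with a hypothetical stand-in reader (ReadAll is not the real ReadFile):

    #include <cstdio>
    #include <memory>

    // Stand-in for ReadFile: returns nullptr on failure and may leave *size
    // untouched, which is why the caller must pre-initialize it.
    std::unique_ptr<char[]> ReadAll(const char *path, size_t *size) {
      FILE *fp = std::fopen(path, "rb");
      if (fp == nullptr) return nullptr;        // *size never written here
      std::fseek(fp, 0, SEEK_END);
      long len = std::ftell(fp);
      if (len <= 0) { std::fclose(fp); return nullptr; }
      std::rewind(fp);
      auto buf = std::make_unique<char[]>(static_cast<size_t>(len));
      *size = std::fread(buf.get(), 1, static_cast<size_t>(len), fp);
      std::fclose(fp);
      return buf;
    }

    int main() {
      size_t size = 0;  // initialized up front, as in Storage::Load
      auto buf = ReadAll("model.ms", &size);
      if (buf == nullptr || size == 0) return 1;  // no uninitialized read of size
      return 0;
    }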
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_COMMON_STORAGE_H_ | |||||
| #define PREDICT_COMMON_STORAGE_H_ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_STORAGE_H | |||||
| #define MINDSPORE_LITE_TOOLS_COMMON_STORAGE_H | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <string> | #include <string> | ||||
| @@ -27,11 +27,11 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| class Storage { | class Storage { | ||||
| public: | public: | ||||
| int Save(const schema::MetaGraphT &graph, const std::string &outputPath); | |||||
| static int Save(const schema::MetaGraphT &graph, const std::string &outputPath); | |||||
| schema::MetaGraphT *Load(const std::string &inputPath); | |||||
| static schema::MetaGraphT *Load(const std::string &inputPath); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_COMMON_STORAGE_H_ | |||||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_STORAGE_H | |||||
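Making Save and Load static records that Storage carries no per-instance state and spares callers a throwaway object. A sketch of the idiom with a simplified stateless class (FileUtil is illustrative, not part of the codebase):

    #include <fstream>
    #include <string>

    class FileUtil {
     public:
      // Static: touches no member state, so no instance is needed.
      static bool Exists(const std::string &path) { return std::ifstream(path).good(); }
    };

    int main() {
      bool ok = FileUtil::Exists("model.ms");  // called through the class name
      return ok ? 0 : 1;
    }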
| @@ -14,7 +14,6 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <cfloat> | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "tools/common/tensor_util.h" | #include "tools/common/tensor_util.h" | ||||
| #include "tools/common/graph_util.h" | #include "tools/common/graph_util.h" | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_PREDICT_TENSOR_UTIL_H | |||||
| #define MINDSPORE_PREDICT_TENSOR_UTIL_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_TENSOR_UTIL_H | |||||
| #define MINDSPORE_LITE_TOOLS_COMMON_TENSOR_UTIL_H | |||||
| #include <cmath> | #include <cmath> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| @@ -58,13 +58,11 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem | |||||
| std::unique_ptr<schema::QuantParamT> CopyQuantParamArrayT( | std::unique_ptr<schema::QuantParamT> CopyQuantParamArrayT( | ||||
| const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray); | const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray); | ||||
| using MSGraphDefTPtr = std::shared_ptr<schema::MetaGraphT>; | |||||
| enum Category { CONST = 0, GRAPH_INPUT = 1, OP_OUTPUT = 2, TF_CONST = 3 }; | enum Category { CONST = 0, GRAPH_INPUT = 1, OP_OUTPUT = 2, TF_CONST = 3 }; | ||||
| class TensorCache { | class TensorCache { | ||||
| public: | public: | ||||
| TensorCache() {} | |||||
| TensorCache() = default; | |||||
| ~TensorCache() { tensors.clear(); } | ~TensorCache() { tensors.clear(); } | ||||
| @@ -97,12 +95,12 @@ class TensorCache { | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| void UpdateTensorIndex(const std::string &name, int index) { | |||||
| void UpdateTensorIndex(const std::string &name, int idx) { | |||||
| auto iter = tensorIndex.find(name); | auto iter = tensorIndex.find(name); | ||||
| if (iter != tensorIndex.end()) { | if (iter != tensorIndex.end()) { | ||||
| tensorIndex[name] = index; | |||||
| tensorIndex[name] = idx; | |||||
| } else { | } else { | ||||
| tensorIndex.insert(make_pair(name, index)); | |||||
| tensorIndex.insert(make_pair(name, idx)); | |||||
| } | } | ||||
| } | } | ||||
| @@ -120,4 +118,4 @@ class TensorCache { | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_PREDICT_TENSOR_UTIL_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_TENSOR_UTIL_H | |||||
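The find-then-assign-or-insert body of UpdateTensorIndex is correct, but with a string-keyed map the same insert-or-update semantics fits in one lookup via operator[]. A sketch, assuming tensorIndex is a std::unordered_map<std::string, int> (the actual container type lives in the class definition, not shown in this hunk):

    #include <string>
    #include <unordered_map>

    void UpdateTensorIndex(std::unordered_map<std::string, int> *tensorIndex,
                           const std::string &name, int idx) {
      // operator[] default-constructs the value on first use and then assigns:
      // one hash lookup instead of find() followed by insert().
      (*tensorIndex)[name] = idx;
    }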
| @@ -38,17 +38,16 @@ STATUS CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| // set default params | |||||
| attr->outMaxValue = false; | attr->outMaxValue = false; | ||||
| attr->topK = 1; | attr->topK = 1; | ||||
| const caffe::ArgMaxParameter argmaxParam = proto.argmax_param(); | |||||
| const caffe::ArgMaxParameter &argmaxParam = proto.argmax_param(); | |||||
| if (argmaxParam.has_out_max_val()) { | if (argmaxParam.has_out_max_val()) { | ||||
| attr->outMaxValue = argmaxParam.out_max_val(); | attr->outMaxValue = argmaxParam.out_max_val(); | ||||
| } | } | ||||
| if (argmaxParam.has_top_k()) { | if (argmaxParam.has_top_k()) { | ||||
| attr->topK = argmaxParam.top_k(); | attr->topK = argmaxParam.top_k(); | ||||
| } | } | ||||
| int32_t axisType; | |||||
| int32_t axisType = 0; | |||||
| int32_t axis = 0; | int32_t axis = 0; | ||||
| if (!argmaxParam.has_axis()) { | if (!argmaxParam.has_axis()) { | ||||
| axisType = 2; | axisType = 2; | ||||
| @@ -26,7 +26,8 @@ namespace lite { | |||||
| class CaffeArgMaxParser : public CaffeNodeParser { | class CaffeArgMaxParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeArgMaxParser() : CaffeNodeParser("argmax") {} | CaffeArgMaxParser() : CaffeNodeParser("argmax") {} | ||||
| ~CaffeArgMaxParser() = default; | |||||
| ~CaffeArgMaxParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| }; | }; | ||||
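The recurring `const caffe::XxxParameter param = proto.xxx_param();` to `const ... &param` change in these parsers matters because protobuf accessors return a const reference; binding it to a value deep-copies the whole sub-message on every Parse call. A self-contained analogue of the two bindings (Param and Layer are stand-ins for generated protobuf types):

    #include <string>
    #include <vector>

    struct Param {
      std::vector<float> coeff;  // stands in for repeated protobuf fields
    };
    struct Layer {
      Param param_;
      const Param &param() const { return param_; }  // protobuf-style accessor
    };

    void Parse(const Layer &layer) {
      const Param copied = layer.param();   // deep-copies coeff element by element
      const Param &alias = layer.param();   // no copy: just aliases layer.param_
      (void)copied;
      (void)alias;
    }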
| @@ -19,12 +19,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "tools/common/tensor_util.h" | #include "tools/common/tensor_util.h" | ||||
| #define CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT 0.00001 | |||||
| #define CAFFE_BATCH_NORM_ESP_DEFAULT_DIFF_FLOAT 0.000000001 | |||||
| static const int CAFFE_BATCHNORMAL_BOTTOM_SIZE = 1; | |||||
| static const int CAFFE_BATCHNORMAL_TOP_SIZE = 1; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| using STATUS = int; | using STATUS = int; | ||||
| @@ -32,6 +26,10 @@ using STATUS = int; | |||||
| STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | ||||
| MS_LOG(DEBUG) << "parse CaffeBatchNormParser"; | MS_LOG(DEBUG) << "parse CaffeBatchNormParser"; | ||||
| if (weightVec == nullptr) { | |||||
| MS_LOG(ERROR) << "weightVec is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -48,43 +46,38 @@ STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caf | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::BatchNormParameter batchNormParam = proto.batch_norm_param(); | |||||
| // check bottom size | |||||
| if (proto.bottom_size() != CAFFE_BATCHNORMAL_BOTTOM_SIZE) { | |||||
| MS_LOG(ERROR) << "Layer " << proto.name().c_str() << "bottom numbers is error, it must be " | |||||
| << CAFFE_BATCHNORMAL_BOTTOM_SIZE << "but is " << proto.bottom_size(); | |||||
| const caffe::BatchNormParameter &batchNormParam = proto.batch_norm_param(); | |||||
| if (proto.bottom_size() != 1) { | |||||
| MS_LOG(ERROR) << "Layer " << proto.name().c_str() << "bottom numbers is error, it must be 1, but is " | |||||
| << proto.bottom_size(); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| // check top size | |||||
| if (proto.top_size() != CAFFE_BATCHNORMAL_TOP_SIZE) { | |||||
| MS_LOG(ERROR) << "Layer " << proto.name().c_str() << "top numbers is error, it must be " | |||||
| << CAFFE_BATCHNORMAL_TOP_SIZE << "but is " << proto.top_size(); | |||||
| if (proto.top_size() != 1) { | |||||
| MS_LOG(ERROR) << "Layer " << proto.name().c_str() << "top numbers is error, it must be 1, but is " | |||||
| << proto.top_size(); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (batchNormParam.has_eps()) { | if (batchNormParam.has_eps()) { | ||||
| if (fabs(CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT - batchNormParam.eps()) < CAFFE_BATCH_NORM_ESP_DEFAULT_DIFF_FLOAT) { | |||||
| attr->epsilon = CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT; | |||||
| if (std::fabs(1e-5 - batchNormParam.eps()) < 1e-9) { | |||||
| attr->epsilon = 1e-5; | |||||
| } else { | } else { | ||||
| auto tmpAuto = batchNormParam.eps(); | auto tmpAuto = batchNormParam.eps(); | ||||
| attr->epsilon = tmpAuto; | attr->epsilon = tmpAuto; | ||||
| } | } | ||||
| } else { | } else { | ||||
| attr->epsilon = CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT; | |||||
| attr->epsilon = 1e-5; | |||||
| } | } | ||||
| const float blob2Data = | const float blob2Data = | ||||
| (weight.blobs(2).double_data_size() > 0) ? weight.blobs(2).double_data(0) : weight.blobs(2).data(0); | (weight.blobs(2).double_data_size() > 0) ? weight.blobs(2).double_data(0) : weight.blobs(2).data(0); | ||||
| const float scaleFactor = blob2Data == 0 ? 0 : 1 / blob2Data; | const float scaleFactor = blob2Data == 0 ? 0 : 1 / blob2Data; | ||||
| // parse weight gamma | |||||
| auto gamma = ConvertWeight(weight.blobs(0)); | auto gamma = ConvertWeight(weight.blobs(0)); | ||||
| if (gamma == nullptr) { | if (gamma == nullptr) { | ||||
| MS_LOG(ERROR) << "Convert blobs(0) for layer " << weight.name().c_str() << " failed"; | MS_LOG(ERROR) << "Convert blobs(0) for layer " << weight.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto estimatedMean = reinterpret_cast<float *>(gamma->data.data()); | auto estimatedMean = reinterpret_cast<float *>(gamma->data.data()); | ||||
| auto estimatedMeanShapeSize = GetShapeSize(*gamma); | auto estimatedMeanShapeSize = GetShapeSize(*gamma); | ||||
| for (size_t i = 0; i < estimatedMeanShapeSize; i++) { | for (size_t i = 0; i < estimatedMeanShapeSize; i++) { | ||||
| @@ -93,13 +86,11 @@ STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caf | |||||
| estimatedMean = nullptr; | estimatedMean = nullptr; | ||||
| weightVec->push_back(gamma); | weightVec->push_back(gamma); | ||||
| // parse weight beta | |||||
| auto beta = ConvertWeight(weight.blobs(1)); | auto beta = ConvertWeight(weight.blobs(1)); | ||||
| if (beta == nullptr) { | if (beta == nullptr) { | ||||
| MS_LOG(ERROR) << "Convert blobs(1) for layer " << weight.name().c_str() << " failed"; | MS_LOG(ERROR) << "Convert blobs(1) for layer " << weight.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto estimatedVariance = reinterpret_cast<float *>(beta->data.data()); | auto estimatedVariance = reinterpret_cast<float *>(beta->data.data()); | ||||
| size_t estimatedVarianceShapeSize = GetShapeSize(*beta); | size_t estimatedVarianceShapeSize = GetShapeSize(*beta); | ||||
| for (size_t i = 0; i < estimatedVarianceShapeSize; i++) { | for (size_t i = 0; i < estimatedVarianceShapeSize; i++) { | ||||
| @@ -26,6 +26,7 @@ namespace lite { | |||||
| class CaffeBatchNormParser : public CaffeNodeParser { | class CaffeBatchNormParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeBatchNormParser() : CaffeNodeParser("batchnorm") {} | CaffeBatchNormParser() : CaffeNodeParser("batchnorm") {} | ||||
| ~CaffeBatchNormParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -17,8 +17,6 @@ | |||||
| #include "tools/converter/parser/caffe/caffe_concat_parser.h" | #include "tools/converter/parser/caffe/caffe_concat_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| const int32_t CONCAT_DEFAULT_AXIS = 1; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| @@ -40,7 +38,7 @@ STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::ConcatParameter concatParam = proto.concat_param(); | |||||
| const caffe::ConcatParameter &concatParam = proto.concat_param(); | |||||
| if (concatParam.has_axis() && concatParam.has_concat_dim()) { | if (concatParam.has_axis() && concatParam.has_concat_dim()) { | ||||
| MS_LOG(ERROR) << "Concat param in caffe have concat_dim and axis simultaneously, return fail"; | MS_LOG(ERROR) << "Concat param in caffe have concat_dim and axis simultaneously, return fail"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -48,19 +46,19 @@ STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||||
| if (concatParam.has_concat_dim()) { | if (concatParam.has_concat_dim()) { | ||||
| MS_LOG(DEBUG) << "Concat dim , set axis: " << concatParam.concat_dim(); | MS_LOG(DEBUG) << "Concat dim , set axis: " << concatParam.concat_dim(); | ||||
| int32_t concat_dim_value = (int32_t)concatParam.concat_dim(); | |||||
| auto concat_dim_value = (int32_t)concatParam.concat_dim(); | |||||
| if (concat_dim_value < 0) { | if (concat_dim_value < 0) { | ||||
| MS_LOG(ERROR) << "concat_dim value in model is smaller than 0:" << concat_dim_value; | MS_LOG(ERROR) << "concat_dim value in model is smaller than 0:" << concat_dim_value; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->axis = concat_dim_value; | attr->axis = concat_dim_value; | ||||
| } else if (concatParam.has_axis()) { | } else if (concatParam.has_axis()) { | ||||
| MS_LOG(DEBUG) << "axis , set axis: " << concatParam.axis(); | |||||
| int32_t tmpInt = (int32_t)concatParam.axis(); | |||||
| MS_LOG(DEBUG) << "set axis: " << concatParam.axis(); | |||||
| auto tmpInt = (int32_t)concatParam.axis(); | |||||
| attr->axis = tmpInt; | attr->axis = tmpInt; | ||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "default , set axis: " << CONCAT_DEFAULT_AXIS; | |||||
| attr->axis = CONCAT_DEFAULT_AXIS; | |||||
| MS_LOG(DEBUG) << "by default, set axis = 1"; | |||||
| attr->axis = 1; | |||||
| } | } | ||||
| attr->n = proto.bottom_size(); | attr->n = proto.bottom_size(); | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeConcatParser : public CaffeNodeParser { | class CaffeConcatParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeConcatParser() : CaffeNodeParser("concat") {} | CaffeConcatParser() : CaffeNodeParser("concat") {} | ||||
| ~CaffeConcatParser() = default; | |||||
| ~CaffeConcatParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -17,13 +17,6 @@ | |||||
| #include "tools/converter/parser/caffe/caffe_conv_base_parser.h" | #include "tools/converter/parser/caffe/caffe_conv_base_parser.h" | ||||
| #include <algorithm> | #include <algorithm> | ||||
| const uint32_t PAD_DEFAULT_VALUE = 0; | |||||
| const uint32_t STRIDE_DEFAULT_VALUE = 1; | |||||
| const uint32_t DILATION_DEFAULT_VALUE = 1; | |||||
| const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; | |||||
| const uint32_t DEFAULT_CONV_GROUP = 1; | |||||
| static const int CAFFE_CONV_BIAS_DIM_NUM = 1; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS CaffeConvBaseParser::ParsePads(const caffe::ConvolutionParameter &convParam, std::vector<int64_t> *pad) { | STATUS CaffeConvBaseParser::ParsePads(const caffe::ConvolutionParameter &convParam, std::vector<int64_t> *pad) { | ||||
| @@ -40,15 +33,15 @@ STATUS CaffeConvBaseParser::ParsePads(const caffe::ConvolutionParameter &convPar | |||||
| } | } | ||||
| if (!convParam.has_pad_h()) { | if (!convParam.has_pad_h()) { | ||||
| (*pad)[0] = PAD_DEFAULT_VALUE; | |||||
| (*pad)[1] = PAD_DEFAULT_VALUE; | |||||
| (*pad)[0] = 0; | |||||
| (*pad)[1] = 0; | |||||
| (*pad)[2] = convParam.pad_w(); | (*pad)[2] = convParam.pad_w(); | ||||
| (*pad)[3] = convParam.pad_w(); | (*pad)[3] = convParam.pad_w(); | ||||
| } else if (!convParam.has_pad_w()) { | } else if (!convParam.has_pad_w()) { | ||||
| (*pad)[0] = convParam.pad_h(); | (*pad)[0] = convParam.pad_h(); | ||||
| (*pad)[1] = convParam.pad_h(); | (*pad)[1] = convParam.pad_h(); | ||||
| (*pad)[2] = PAD_DEFAULT_VALUE; | |||||
| (*pad)[3] = PAD_DEFAULT_VALUE; | |||||
| (*pad)[2] = 0; | |||||
| (*pad)[3] = 0; | |||||
| } else { | } else { | ||||
| (*pad)[0] = convParam.pad_h(); | (*pad)[0] = convParam.pad_h(); | ||||
| (*pad)[1] = convParam.pad_h(); | (*pad)[1] = convParam.pad_h(); | ||||
| @@ -56,15 +49,14 @@ STATUS CaffeConvBaseParser::ParsePads(const caffe::ConvolutionParameter &convPar | |||||
| (*pad)[3] = convParam.pad_w(); | (*pad)[3] = convParam.pad_w(); | ||||
| } | } | ||||
| } else { | } else { | ||||
| // default 2D | |||||
| const int num_pad_dims = convParam.pad_size(); | const int num_pad_dims = convParam.pad_size(); | ||||
| int num_spatial_dims = std::max(num_pad_dims, SPATIAL_DIM_DEFAULT_SIZE); | |||||
| int num_spatial_dims = std::max(num_pad_dims, 2); | |||||
| std::vector<int64_t> vec; | std::vector<int64_t> vec; | ||||
| vec.reserve(num_spatial_dims); | |||||
| for (int i = 0; i < num_spatial_dims; ++i) { | for (int i = 0; i < num_spatial_dims; ++i) { | ||||
| vec.push_back((num_pad_dims == 0) ? PAD_DEFAULT_VALUE : convParam.pad((num_pad_dims == 1) ? 0 : i)); | |||||
| vec.push_back((num_pad_dims == 0) ? 0 : convParam.pad((num_pad_dims == 1) ? 0 : i)); | |||||
| } | } | ||||
| // default 2D | |||||
| (*pad)[0] = vec[0]; | (*pad)[0] = vec[0]; | ||||
| (*pad)[1] = vec[0]; | (*pad)[1] = vec[0]; | ||||
| (*pad)[2] = vec[1]; | (*pad)[2] = vec[1]; | ||||
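The pad logic above follows Caffe's two spellings of padding: explicit pad_h/pad_w, with the unspecified axis defaulting to 0, or the repeated pad field, where zero entries mean "default everywhere" and a single entry is broadcast to every spatial dimension. A condensed, protobuf-free sketch of the repeated-field rule, which the stride and dilation branches below reuse:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Caffe repeated-field rule: 0 entries -> default everywhere,
    // 1 entry -> broadcast to all spatial dims, else one entry per dim.
    std::vector<int64_t> ResolveRepeated(const std::vector<int64_t> &field,
                                         int64_t default_value) {
      const int n = static_cast<int>(field.size());
      const int dims = std::max(n, 2);  // default 2D, as in the parser
      std::vector<int64_t> out;
      out.reserve(dims);
      for (int i = 0; i < dims; ++i) {
        out.push_back(n == 0 ? default_value : field[n == 1 ? 0 : i]);
      }
      return out;
    }
    // ResolveRepeated({}, 0)     -> {0, 0}
    // ResolveRepeated({3}, 0)    -> {3, 3}
    // ResolveRepeated({2, 1}, 0) -> {2, 1}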
| @@ -87,13 +79,13 @@ STATUS CaffeConvBaseParser::ParseStrides(const caffe::ConvolutionParameter &conv | |||||
| (*stride)[1] = convParam.stride_w(); | (*stride)[1] = convParam.stride_w(); | ||||
| } else { | } else { | ||||
| const int num_stride_dims = convParam.stride_size(); | const int num_stride_dims = convParam.stride_size(); | ||||
| int num_spatial_dims = std::max(num_stride_dims, SPATIAL_DIM_DEFAULT_SIZE); | |||||
| int num_spatial_dims = std::max(num_stride_dims, 2); | |||||
| std::vector<int64_t> vec; | std::vector<int64_t> vec; | ||||
| vec.reserve(num_spatial_dims); | |||||
| for (int i = 0; i < num_spatial_dims; ++i) { | for (int i = 0; i < num_spatial_dims; ++i) { | ||||
| vec.push_back((num_stride_dims == 0) ? STRIDE_DEFAULT_VALUE : convParam.stride((num_stride_dims == 1) ? 0 : i)); | |||||
| vec.push_back((num_stride_dims == 0) ? 1 : convParam.stride((num_stride_dims == 1) ? 0 : i)); | |||||
| } | } | ||||
| // default 2D | |||||
| (*stride)[0] = vec[0]; | (*stride)[0] = vec[0]; | ||||
| (*stride)[1] = vec[1]; | (*stride)[1] = vec[1]; | ||||
| } | } | ||||
| @@ -103,17 +95,15 @@ STATUS CaffeConvBaseParser::ParseStrides(const caffe::ConvolutionParameter &conv | |||||
| STATUS CaffeConvBaseParser::ParseDilations(const caffe::ConvolutionParameter &convParam, | STATUS CaffeConvBaseParser::ParseDilations(const caffe::ConvolutionParameter &convParam, | ||||
| std::vector<int64_t> *dilation) { | std::vector<int64_t> *dilation) { | ||||
| const int num_dilation_dims = convParam.dilation_size(); | const int num_dilation_dims = convParam.dilation_size(); | ||||
| int num_spatial_dims = std::max(num_dilation_dims, SPATIAL_DIM_DEFAULT_SIZE); | |||||
| int num_spatial_dims = std::max(num_dilation_dims, 2); | |||||
| std::vector<int64_t> vec; | std::vector<int64_t> vec; | ||||
| vec.reserve(num_spatial_dims); | |||||
| for (int i = 0; i < num_spatial_dims; ++i) { | for (int i = 0; i < num_spatial_dims; ++i) { | ||||
| vec.push_back((num_dilation_dims == 0) ? DILATION_DEFAULT_VALUE | |||||
| : convParam.dilation((num_dilation_dims == 1) ? 0 : i)); | |||||
| vec.push_back((num_dilation_dims == 0) ? 1 : convParam.dilation((num_dilation_dims == 1) ? 0 : i)); | |||||
| } | } | ||||
| // default 2D | |||||
| (*dilation)[0] = vec[0]; | (*dilation)[0] = vec[0]; | ||||
| (*dilation)[1] = vec[1]; | (*dilation)[1] = vec[1]; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -131,9 +121,11 @@ STATUS CaffeConvBaseParser::ParseKernels(const caffe::ConvolutionParameter &conv | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } else if (convParam.kernel_size_size() != 0) { | } else if (convParam.kernel_size_size() != 0) { | ||||
| int kernel_size = convParam.kernel_size_size(); | |||||
| int num_spatial_dims = std::max(kernel_size, SPATIAL_DIM_DEFAULT_SIZE); | |||||
| const int kernel_size = convParam.kernel_size_size(); | |||||
| int num_spatial_dims = std::max(kernel_size, 2); | |||||
| std::vector<int64_t> vec; | std::vector<int64_t> vec; | ||||
| vec.reserve(num_spatial_dims); | |||||
| for (int i = 0; i < num_spatial_dims; i++) { | for (int i = 0; i < num_spatial_dims; i++) { | ||||
| vec.push_back(convParam.kernel_size((kernel_size == 1) ? 0 : i)); | vec.push_back(convParam.kernel_size((kernel_size == 1) ? 0 : i)); | ||||
| } | } | ||||
| @@ -141,24 +133,25 @@ STATUS CaffeConvBaseParser::ParseKernels(const caffe::ConvolutionParameter &conv | |||||
| (*kernel)[0] = vec[0]; | (*kernel)[0] = vec[0]; | ||||
| (*kernel)[1] = vec[1]; | (*kernel)[1] = vec[1]; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "conv does not have kernel info."; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int CaffeConvBaseParser::ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType) { | int CaffeConvBaseParser::ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType) { | ||||
| // group default 1 | |||||
| int group = 0; | |||||
| if (convParam.has_group()) { | if (convParam.has_group()) { | ||||
| group = convParam.group(); | |||||
| return convParam.group(); | |||||
| } else { | } else { | ||||
| layerType == "ConvolutionDepthwise" ? (group = convParam.num_output()) : (group = DEFAULT_CONV_GROUP); | |||||
| return layerType == "ConvolutionDepthwise" ? static_cast<int>(convParam.num_output()) : 1; | |||||
| } | } | ||||
| return group; | |||||
| } | } | ||||
| int CaffeConvBaseParser::ParseChannelOut(const caffe::ConvolutionParameter &convParam, int32_t *channelOut) { | int CaffeConvBaseParser::ParseChannelOut(const caffe::ConvolutionParameter &convParam, int32_t *channelOut) { | ||||
| MS_ASSERT(channelOut != nullptr); | |||||
| if (channelOut == nullptr) { | |||||
| MS_LOG(ERROR) << "channelOut is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (!convParam.has_num_output()) { | if (!convParam.has_num_output()) { | ||||
| MS_LOG(ERROR) << "Parse num_output for failed."; | MS_LOG(ERROR) << "Parse num_output for failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
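The rewritten ParseGroup states Caffe's convention directly: an explicit group wins, ConvolutionDepthwise implies one group per output channel, and everything else defaults to 1. The same resolution as a standalone function:

    #include <cstdint>
    #include <string>

    int ResolveGroup(bool has_group, uint32_t group, uint32_t num_output,
                     const std::string &layerType) {
      if (has_group) return static_cast<int>(group);  // explicit value wins
      return layerType == "ConvolutionDepthwise"
                 ? static_cast<int>(num_output)       // depthwise: one group per channel
                 : 1;                                 // ordinary convolution default
    }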
| @@ -169,7 +162,11 @@ int CaffeConvBaseParser::ParseChannelOut(const caffe::ConvolutionParameter &conv | |||||
| STATUS CaffeConvBaseParser::ParseWeight(const caffe::LayerParameter &weight, | STATUS CaffeConvBaseParser::ParseWeight(const caffe::LayerParameter &weight, | ||||
| std::vector<schema::TensorT *> *weightVec) { | std::vector<schema::TensorT *> *weightVec) { | ||||
| // Layer must have Filter | |||||
| if (weightVec == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (weight.blobs_size() == 0) { | if (weight.blobs_size() == 0) { | ||||
| MS_LOG(ERROR) << "No filter data in layer " << weight.name().c_str(); | MS_LOG(ERROR) << "No filter data in layer " << weight.name().c_str(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -182,8 +179,7 @@ STATUS CaffeConvBaseParser::ParseWeight(const caffe::LayerParameter &weight, | |||||
| } | } | ||||
| weightVec->push_back(filter); | weightVec->push_back(filter); | ||||
| // parse bias | |||||
| const caffe::ConvolutionParameter convParam = weight.convolution_param(); | |||||
| const caffe::ConvolutionParameter &convParam = weight.convolution_param(); | |||||
| if (convParam.bias_term() && weight.blobs_size() > 1) { | if (convParam.bias_term() && weight.blobs_size() > 1) { | ||||
| auto bias = ConvertWeight(weight.blobs(1)); | auto bias = ConvertWeight(weight.blobs(1)); | ||||
| if (bias == nullptr) { | if (bias == nullptr) { | ||||
| @@ -192,7 +188,7 @@ STATUS CaffeConvBaseParser::ParseWeight(const caffe::LayerParameter &weight, | |||||
| } | } | ||||
| std::vector<int32_t> shape = bias->dims; | std::vector<int32_t> shape = bias->dims; | ||||
| if (shape.size() != CAFFE_CONV_BIAS_DIM_NUM) { | |||||
| if (shape.size() != 1) { | |||||
| MS_LOG(ERROR) << "Bias dim-num of layer " << weight.name().c_str() << " is not supported"; | MS_LOG(ERROR) << "Bias dim-num of layer " << weight.name().c_str() << " is not supported"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -26,23 +26,23 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| class CaffeConvBaseParser { | class CaffeConvBaseParser { | ||||
| public: | public: | ||||
| CaffeConvBaseParser() {} | |||||
| CaffeConvBaseParser() = default; | |||||
| virtual ~CaffeConvBaseParser() {} | |||||
| virtual ~CaffeConvBaseParser() = default; | |||||
| STATUS ParsePads(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *pad); | |||||
| static STATUS ParsePads(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *pad); | |||||
| STATUS ParseStrides(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *stride); | |||||
| static STATUS ParseStrides(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *stride); | |||||
| STATUS ParseDilations(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *dilation); | |||||
| static STATUS ParseDilations(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *dilation); | |||||
| STATUS ParseKernels(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *kernel); | |||||
| static STATUS ParseKernels(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *kernel); | |||||
| int ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType); | |||||
| static int ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType); | |||||
| int ParseChannelOut(const caffe::ConvolutionParameter &convParam, int32_t *channelOut); | |||||
| static int ParseChannelOut(const caffe::ConvolutionParameter &convParam, int32_t *channelOut); | |||||
| STATUS ParseWeight(const caffe::LayerParameter &weight, std::vector<schema::TensorT *> *weightVec); | |||||
| static STATUS ParseWeight(const caffe::LayerParameter &weight, std::vector<schema::TensorT *> *weightVec); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -54,7 +54,10 @@ STATUS CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema: | |||||
| STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | ||||
| MS_LOG(DEBUG) << "parse CaffeConvolutionParser"; | MS_LOG(DEBUG) << "parse CaffeConvolutionParser"; | ||||
| if (weightVec == nullptr) { | |||||
| MS_LOG(ERROR) << "weightVec is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -73,11 +76,10 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||||
| attr->format = schema::Format_NCHW; | attr->format = schema::Format_NCHW; | ||||
| const caffe::ConvolutionParameter convParam = proto.convolution_param(); | |||||
| CaffeConvBaseParser convParser; | |||||
| const caffe::ConvolutionParameter &convParam = proto.convolution_param(); | |||||
| // parse pad | // parse pad | ||||
| std::vector<int64_t> pad(4, 0); | std::vector<int64_t> pad(4, 0); | ||||
| auto status = convParser.ParsePads(convParam, &pad); | |||||
| auto status = CaffeConvBaseParser::ParsePads(convParam, &pad); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -89,7 +91,7 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||||
| // parse stride | // parse stride | ||||
| std::vector<int64_t> stride(2, 0); | std::vector<int64_t> stride(2, 0); | ||||
| status = convParser.ParseStrides(convParam, &stride); | |||||
| status = CaffeConvBaseParser::ParseStrides(convParam, &stride); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -99,7 +101,7 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||||
| // parse dilation | // parse dilation | ||||
| std::vector<int64_t> dilation(2, 0); | std::vector<int64_t> dilation(2, 0); | ||||
| status = convParser.ParseDilations(convParam, &dilation); | |||||
| status = CaffeConvBaseParser::ParseDilations(convParam, &dilation); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -109,7 +111,7 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||||
| // parse kernel | // parse kernel | ||||
| std::vector<int64_t> kernel(2, 0); | std::vector<int64_t> kernel(2, 0); | ||||
| status = convParser.ParseKernels(convParam, &kernel); | |||||
| status = CaffeConvBaseParser::ParseKernels(convParam, &kernel); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -118,8 +120,8 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||||
| attr->kernelW = kernel[1]; | attr->kernelW = kernel[1]; | ||||
| attr->hasBias = convParam.bias_term(); | attr->hasBias = convParam.bias_term(); | ||||
| attr->group = convParser.ParseGroup(convParam, proto.type()); | |||||
| auto ret = convParser.ParseChannelOut(convParam, &(attr->channelOut)); | |||||
| attr->group = CaffeConvBaseParser::ParseGroup(convParam, proto.type()); | |||||
| auto ret = CaffeConvBaseParser::ParseChannelOut(convParam, &(attr->channelOut)); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "conv channel out failed"; | MS_LOG(ERROR) << "conv channel out failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -128,7 +130,6 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||||
| if (weightBlob.has_shape()) { | if (weightBlob.has_shape()) { | ||||
| attr->channelIn = weightBlob.shape().dim(1) * attr->group; | attr->channelIn = weightBlob.shape().dim(1) * attr->group; | ||||
| } else { | } else { | ||||
| // get shape information from Blob parameters(caffe proto v1) | |||||
| attr->channelIn = weightBlob.channels() * attr->group; | attr->channelIn = weightBlob.channels() * attr->group; | ||||
| } | } | ||||
| attr->padMode = schema::PadMode_CAFFE; | attr->padMode = schema::PadMode_CAFFE; | ||||
| @@ -143,7 +144,7 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| status = convParser.ParseWeight(weight, weightVec); | |||||
| status = CaffeConvBaseParser::ParseWeight(weight, weightVec); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseWeight for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParseWeight for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -27,13 +27,13 @@ namespace lite { | |||||
| class CaffeConvolutionParser : public CaffeNodeParser { | class CaffeConvolutionParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeConvolutionParser() : CaffeNodeParser("convolution") {} | CaffeConvolutionParser() : CaffeNodeParser("convolution") {} | ||||
| ~CaffeConvolutionParser() = default; | |||||
| ~CaffeConvolutionParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| private: | private: | ||||
| STATUS ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr); | |||||
| static STATUS ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,13 +17,15 @@ | |||||
| #include "tools/converter/parser/caffe/caffe_crop_parser.h" | #include "tools/converter/parser/caffe/caffe_crop_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| const int32_t CROP_AXIS = 2; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS CaffeCropParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeCropParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | ||||
| MS_LOG(DEBUG) << "parse CaffeCropParser"; | MS_LOG(DEBUG) << "parse CaffeCropParser"; | ||||
| if (weightVec == nullptr) { | |||||
| MS_LOG(ERROR) << "weightVec is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -41,22 +43,23 @@ STATUS CaffeCropParser::Parse(const caffe::LayerParameter &proto, const caffe::L | |||||
| } | } | ||||
| if (!proto.has_crop_param()) { | if (!proto.has_crop_param()) { | ||||
| attr->axis = CROP_AXIS; | |||||
| attr->axis = 2; | |||||
| std::vector<int64_t> offsets(2, 0); | std::vector<int64_t> offsets(2, 0); | ||||
| attr->offsets = offsets; | attr->offsets = offsets; | ||||
| } else { | } else { | ||||
| const caffe::CropParameter cropParam = proto.crop_param(); | |||||
| const caffe::CropParameter &cropParam = proto.crop_param(); | |||||
| if (cropParam.has_axis()) { | if (cropParam.has_axis()) { | ||||
| if (cropParam.axis() == -1) { | if (cropParam.axis() == -1) { | ||||
| MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | ||||
| } | } | ||||
| attr->axis = cropParam.axis(); | attr->axis = cropParam.axis(); | ||||
| } else { | } else { | ||||
| attr->axis = CROP_AXIS; | |||||
| attr->axis = 2; | |||||
| } | } | ||||
| if (cropParam.offset_size() != 0) { | if (cropParam.offset_size() != 0) { | ||||
| std::vector<int64_t> offsets; | std::vector<int64_t> offsets; | ||||
| offsets.reserve(cropParam.offset_size()); | |||||
| for (int i = 0; i < cropParam.offset_size(); i++) { | for (int i = 0; i < cropParam.offset_size(); i++) { | ||||
| offsets.push_back(cropParam.offset(i)); | offsets.push_back(cropParam.offset(i)); | ||||
| } | } | ||||
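The reserve calls added in this and the conv-base-parser hunks size the vector once before the push_back loop, replacing the series of grow-and-copy reallocations. A minimal sketch:

    #include <cstdint>
    #include <vector>

    std::vector<int64_t> CollectOffsets(int count) {
      std::vector<int64_t> offsets;
      offsets.reserve(count);   // one allocation up front
      for (int i = 0; i < count; ++i) {
        offsets.push_back(i);   // never reallocates while size < count
      }
      return offsets;
    }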
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeCropParser : public CaffeNodeParser { | class CaffeCropParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeCropParser() : CaffeNodeParser("crop") {} | CaffeCropParser() : CaffeNodeParser("crop") {} | ||||
| ~CaffeCropParser() = default; | |||||
| ~CaffeCropParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -54,7 +54,10 @@ STATUS CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, sch | |||||
| STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | ||||
| MS_LOG(DEBUG) << "parse CaffeDeconvolutionParser"; | MS_LOG(DEBUG) << "parse CaffeDeconvolutionParser"; | ||||
| if (weightVec == nullptr) { | |||||
| MS_LOG(ERROR) << "weightVec is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -69,11 +72,10 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||||
| attr->format = schema::Format::Format_NCHW; | attr->format = schema::Format::Format_NCHW; | ||||
| const caffe::ConvolutionParameter convParam = proto.convolution_param(); | |||||
| CaffeConvBaseParser convParser; | |||||
| const caffe::ConvolutionParameter &convParam = proto.convolution_param(); | |||||
| // parse pad | // parse pad | ||||
| std::vector<int64_t> pad(4, 0); | std::vector<int64_t> pad(4, 0); | ||||
| auto status = convParser.ParsePads(convParam, &pad); | |||||
| auto status = CaffeConvBaseParser::ParsePads(convParam, &pad); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -85,7 +87,7 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||||
| // parse stride | // parse stride | ||||
| std::vector<int64_t> stride(2, 0); | std::vector<int64_t> stride(2, 0); | ||||
| status = convParser.ParseStrides(convParam, &stride); | |||||
| status = CaffeConvBaseParser::ParseStrides(convParam, &stride); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -95,7 +97,7 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||||
| // parse dilation | // parse dilation | ||||
| std::vector<int64_t> dilation(2, 0); | std::vector<int64_t> dilation(2, 0); | ||||
| status = convParser.ParseDilations(convParam, &dilation); | |||||
| status = CaffeConvBaseParser::ParseDilations(convParam, &dilation); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -105,7 +107,7 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||||
| // parse kernel | // parse kernel | ||||
| std::vector<int64_t> kernel(2, 0); | std::vector<int64_t> kernel(2, 0); | ||||
| status = convParser.ParseKernels(convParam, &kernel); | |||||
| status = CaffeConvBaseParser::ParseKernels(convParam, &kernel); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -114,8 +116,8 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||||
| attr->kernelW = kernel[1]; | attr->kernelW = kernel[1]; | ||||
| attr->hasBias = convParam.bias_term(); | attr->hasBias = convParam.bias_term(); | ||||
| attr->group = convParser.ParseGroup(convParam, proto.type()); | |||||
| auto ret = convParser.ParseChannelOut(convParam, &(attr->channelOut)); | |||||
| attr->group = CaffeConvBaseParser::ParseGroup(convParam, proto.type()); | |||||
| auto ret = CaffeConvBaseParser::ParseChannelOut(convParam, &(attr->channelOut)); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "deconv channel get failed"; | MS_LOG(ERROR) << "deconv channel get failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -127,7 +129,6 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||||
| else | else | ||||
| attr->channelIn = weightBlob.shape().dim(1) * attr->group; | attr->channelIn = weightBlob.shape().dim(1) * attr->group; | ||||
| } else { | } else { | ||||
| // get shape information from Blob parameters(caffe proto v1) | |||||
| attr->channelIn = weightBlob.num() * attr->group; | attr->channelIn = weightBlob.num() * attr->group; | ||||
| } | } | ||||
| attr->padMode = schema::PadMode_CAFFE; | attr->padMode = schema::PadMode_CAFFE; | ||||
| @@ -142,7 +143,7 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| status = convParser.ParseWeight(weight, weightVec); | |||||
| status = CaffeConvBaseParser::ParseWeight(weight, weightVec); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseWeight for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParseWeight for " << proto.name().c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -27,13 +27,13 @@ namespace lite { | |||||
| class CaffeDeconvolutionParser : public CaffeNodeParser { | class CaffeDeconvolutionParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeDeconvolutionParser() : CaffeNodeParser("deconvolution") {} | CaffeDeconvolutionParser() : CaffeNodeParser("deconvolution") {} | ||||
| ~CaffeDeconvolutionParser() = default; | |||||
| ~CaffeDeconvolutionParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| private: | private: | ||||
| STATUS ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr); | |||||
| static STATUS ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,9 +18,6 @@ | |||||
| #include <cmath> | #include <cmath> | ||||
| #include <memory> | #include <memory> | ||||
| const int ELTWISE_MIN_INPUT_SIZE = 2; | |||||
| const float ELTWISE_SUM_COEFF_EPSILON = 1e-5; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| @@ -42,13 +39,13 @@ STATUS CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (proto.bottom_size() < ELTWISE_MIN_INPUT_SIZE) { | |||||
| if (proto.bottom_size() < 2) { | |||||
| MS_LOG(ERROR) << "Eltwise Op " << proto.name() << " need at least 2 inputs,but input size is " | MS_LOG(ERROR) << "Eltwise Op " << proto.name() << " need at least 2 inputs,but input size is " | ||||
| << proto.bottom_size(); | << proto.bottom_size(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| const caffe::EltwiseParameter eltwiseParam = proto.eltwise_param(); | |||||
| const caffe::EltwiseParameter &eltwiseParam = proto.eltwise_param(); | |||||
| if (eltwiseParam.coeff_size() != 0 && eltwiseParam.coeff_size() != proto.bottom_size()) { | if (eltwiseParam.coeff_size() != 0 && eltwiseParam.coeff_size() != proto.bottom_size()) { | ||||
| MS_LOG(ERROR) << "Coeff size(" << eltwiseParam.coeff_size() | MS_LOG(ERROR) << "Coeff size(" << eltwiseParam.coeff_size() | ||||
| << ") check fail, Eltwise Layer takes one coefficient per bottom blob."; | << ") check fail, Eltwise Layer takes one coefficient per bottom blob."; | ||||
| @@ -60,8 +57,8 @@ STATUS CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (eltwiseParam.coeff_size() != 0 && (fabs(eltwiseParam.coeff(0) - 1) > ELTWISE_SUM_COEFF_EPSILON || | |||||
| fabs(eltwiseParam.coeff(1) - 1) > ELTWISE_SUM_COEFF_EPSILON)) { | |||||
| if (eltwiseParam.coeff_size() != 0 && | |||||
| (std::fabs(eltwiseParam.coeff(0) - 1) > 1e-5 || std::fabs(eltwiseParam.coeff(1) - 1) > 1e-5)) { | |||||
| MS_LOG(ERROR) << "Eltwise only support coefficient 1 for summation now."; | MS_LOG(ERROR) << "Eltwise only support coefficient 1 for summation now."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
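The coefficient check compares against 1 with an explicit tolerance because float serialization rarely round-trips exactly; `std::fabs(x - 1) > 1e-5` rejects only values outside a 1e-5 band around 1. The same test in isolation:

    #include <cmath>

    bool IsUnitCoeff(float coeff) {
      // Exact `coeff == 1.0f` would reject values such as 0.99999994f that
      // come out of float serialization; a 1e-5 band around 1 accepts them.
      return std::fabs(coeff - 1.0f) <= 1e-5f;
    }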
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeEltwiseParser : public CaffeNodeParser { | class CaffeEltwiseParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeEltwiseParser() : CaffeNodeParser("eltwise") {} | CaffeEltwiseParser() : CaffeNodeParser("eltwise") {} | ||||
| ~CaffeEltwiseParser() = default; | |||||
| ~CaffeEltwiseParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -39,7 +39,7 @@ STATUS CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::La | |||||
| } | } | ||||
| if (proto.has_elu_param()) { | if (proto.has_elu_param()) { | ||||
| const caffe::ELUParameter eluParameter = proto.elu_param(); | |||||
| const caffe::ELUParameter &eluParameter = proto.elu_param(); | |||||
| if (eluParameter.has_alpha()) { | if (eluParameter.has_alpha()) { | ||||
| attr->alpha = eluParameter.alpha(); | attr->alpha = eluParameter.alpha(); | ||||
| } | } | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeEluParser : public CaffeNodeParser { | class CaffeEluParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeEluParser() : CaffeNodeParser("elu") {} | CaffeEluParser() : CaffeNodeParser("elu") {} | ||||
| ~CaffeEluParser() = default; | |||||
| ~CaffeEluParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -39,7 +39,7 @@ STATUS CaffeExpParser::Parse(const caffe::LayerParameter &proto, const caffe::La | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::ExpParameter exp_param = proto.exp_param(); | |||||
| const caffe::ExpParameter &exp_param = proto.exp_param(); | |||||
| if (exp_param.has_base()) { | if (exp_param.has_base()) { | ||||
| attr->base = exp_param.base(); | attr->base = exp_param.base(); | ||||
| } else { | } else { | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeExpParser : public CaffeNodeParser { | class CaffeExpParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeExpParser() : CaffeNodeParser("exp") {} | CaffeExpParser() : CaffeNodeParser("exp") {} | ||||
| ~CaffeExpParser() = default; | |||||
| ~CaffeExpParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeFlattenParser : public CaffeNodeParser { | class CaffeFlattenParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeFlattenParser() : CaffeNodeParser("flatten") {} | CaffeFlattenParser() : CaffeNodeParser("flatten") {} | ||||
| ~CaffeFlattenParser() = default; | |||||
| ~CaffeFlattenParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -22,6 +22,10 @@ namespace lite { | |||||
| STATUS CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | ||||
| MS_LOG(DEBUG) << "parse CaffeInnerProductParser"; | MS_LOG(DEBUG) << "parse CaffeInnerProductParser"; | ||||
| if (weightVec == nullptr) { | |||||
| MS_LOG(ERROR) << "weightVec is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -38,7 +42,7 @@ STATUS CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, const | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::InnerProductParameter innerProductParam = proto.inner_product_param(); | |||||
| const caffe::InnerProductParameter &innerProductParam = proto.inner_product_param(); | |||||
| if (!innerProductParam.has_num_output()) { | if (!innerProductParam.has_num_output()) { | ||||
| MS_LOG(ERROR) << "InnerProduct Parse num_output for " << proto.name().c_str() << " failed."; | MS_LOG(ERROR) << "InnerProduct Parse num_output for " << proto.name().c_str() << " failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -62,8 +66,6 @@ STATUS CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, const | |||||
| MS_LOG(ERROR) << "InnerProduct No filter data in layer " << weight.name().c_str(); | MS_LOG(ERROR) << "InnerProduct No filter data in layer " << weight.name().c_str(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| // parse filter | |||||
| auto filter = ConvertWeight(weight.blobs(0)); | auto filter = ConvertWeight(weight.blobs(0)); | ||||
| if (filter == nullptr) { | if (filter == nullptr) { | ||||
| MS_LOG(ERROR) << "InnerProduct parse weight for layer " << weight.name().c_str() << " failed"; | MS_LOG(ERROR) << "InnerProduct parse weight for layer " << weight.name().c_str() << " failed"; | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeInnerProductParser : public CaffeNodeParser { | class CaffeInnerProductParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeInnerProductParser() : CaffeNodeParser("innerproduct") {} | CaffeInnerProductParser() : CaffeNodeParser("innerproduct") {} | ||||
| ~CaffeInnerProductParser() = default; | |||||
| ~CaffeInnerProductParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -47,12 +47,12 @@ STATUS CaffeInspector::ParseInput() { | |||||
| } | } | ||||
| STATUS CaffeInspector::FindInputAndOutput() { | STATUS CaffeInspector::FindInputAndOutput() { | ||||
| for (auto iter : layerBottoms) { | |||||
| for (const auto &iter : layerBottoms) { | |||||
| if (layerTops.find(iter) == layerTops.end()) { | if (layerTops.find(iter) == layerTops.end()) { | ||||
| graphInput.insert(iter); | graphInput.insert(iter); | ||||
| } | } | ||||
| } | } | ||||
| for (auto iter : layerTops) { | |||||
| for (const auto &iter : layerTops) { | |||||
| if (layerBottoms.find(iter) == layerBottoms.end()) { | if (layerBottoms.find(iter) == layerBottoms.end()) { | ||||
| graphOutput.insert(iter); | graphOutput.insert(iter); | ||||
| } | } | ||||
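Same idea in the two loops above: iterating a set of std::string by value (`for (auto iter : layerBottoms)`) copies every element, while `const auto &` binds to the stored string. A minimal sketch, with an invented set name:

    #include <set>
    #include <string>

    int main() {
      std::set<std::string> layer_tops = {"conv1", "relu1"};
      for (const auto &name : layer_tops) {  // binds to the stored string
        (void)name;                          // no per-element copy is made
      }
      return 0;
    }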
| @@ -62,7 +62,7 @@ STATUS CaffeInspector::FindInputAndOutput() { | |||||
| STATUS CaffeInspector::SetTopsAndBottoms() { | STATUS CaffeInspector::SetTopsAndBottoms() { | ||||
| for (int32_t i = 0; i < net.layer_size(); i++) { | for (int32_t i = 0; i < net.layer_size(); i++) { | ||||
| caffe::LayerParameter &layer = const_cast<caffe::LayerParameter &>(net.layer(i)); | |||||
| auto &layer = const_cast<caffe::LayerParameter &>(net.layer(i)); | |||||
| if (layer.top_size() == 1 && layer.bottom_size() == 1 && layer.top(0) == layer.bottom(0)) { | if (layer.top_size() == 1 && layer.bottom_size() == 1 && layer.top(0) == layer.bottom(0)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -38,7 +38,7 @@ STATUS CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::InterpParameter interpParam = proto.interp_param(); | |||||
| const caffe::InterpParameter &interpParam = proto.interp_param(); | |||||
| if (interpParam.has_height()) { | if (interpParam.has_height()) { | ||||
| int64_t height = interpParam.height(); | int64_t height = interpParam.height(); | ||||
| if (height < 0) { | if (height < 0) { | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeInterpParser : public CaffeNodeParser { | class CaffeInterpParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeInterpParser() : CaffeNodeParser("Interp") {} | CaffeInterpParser() : CaffeNodeParser("Interp") {} | ||||
| ~CaffeInterpParser() = default; | |||||
| ~CaffeInterpParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -23,6 +23,11 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | ||||
| std::unique_ptr<schema::TensorT> weight = std::make_unique<schema::TensorT>(); | std::unique_ptr<schema::TensorT> weight = std::make_unique<schema::TensorT>(); | ||||
| if (weight == nullptr) { | |||||
| MS_LOG(ERROR) << "new weight failed"; | |||||
| return nullptr; | |||||
| } | |||||
| weight->format = schema::Format::Format_NCHW; | weight->format = schema::Format::Format_NCHW; | ||||
| std::vector<int32_t> shapeVec; | std::vector<int32_t> shapeVec; | ||||
| ConvertShape(proto, &shapeVec); | ConvertShape(proto, &shapeVec); | ||||
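One caveat about the new guard above: in standard C++, std::make_unique reports allocation failure by throwing std::bad_alloc rather than by returning nullptr, so the check only fires in builds where allocation is configured not to throw (which may well be the case here; that is an assumption about the build flags, not something the patch states). A hedged sketch of an allocation that genuinely can yield nullptr:

    #include <memory>
    #include <new>

    // Hypothetical helper: nothrow allocation really can return nullptr,
    // which is the situation the added guard is written for.
    std::unique_ptr<int> TryAlloc() {
      return std::unique_ptr<int>(new (std::nothrow) int(0));
    }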
| @@ -32,8 +37,7 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||||
| // cal Weight num | // cal Weight num | ||||
| int count = 1; | int count = 1; | ||||
| for (size_t i = 0; i < shapeVec.size(); ++i) { | |||||
| int dim = shapeVec[i]; | |||||
| for (int dim : shapeVec) { | |||||
| if (dim <= 0) { | if (dim <= 0) { | ||||
| MS_LOG(ERROR) << "Convert weight fail, Blob size invalid"; | MS_LOG(ERROR) << "Convert weight fail, Blob size invalid"; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -48,6 +52,7 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||||
| // get weight | // get weight | ||||
| std::unique_ptr<float[]> buf = std::make_unique<float[]>(count); | std::unique_ptr<float[]> buf = std::make_unique<float[]>(count); | ||||
| if (buf == nullptr) { | if (buf == nullptr) { | ||||
| MS_LOG(ERROR) << "new weight buf failed"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (proto.double_data_size() > 0) { | if (proto.double_data_size() > 0) { | ||||
| @@ -74,6 +79,7 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||||
| << "blob.data_size:%d" << proto.data_size(); | << "blob.data_size:%d" << proto.data_size(); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| weight->data.resize(count * sizeof(float)); | weight->data.resize(count * sizeof(float)); | ||||
| const float *data_ptr = proto.data().data(); | const float *data_ptr = proto.data().data(); | ||||
| if (data_ptr == nullptr) { | if (data_ptr == nullptr) { | ||||
| @@ -91,8 +97,12 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||||
| } | } | ||||
| STATUS ConvertShape(const caffe::BlobProto &proto, std::vector<int32_t> *shape) { | STATUS ConvertShape(const caffe::BlobProto &proto, std::vector<int32_t> *shape) { | ||||
| shape->clear(); | |||||
| if (shape == nullptr) { | |||||
| MS_LOG(ERROR) << "shape is null"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| shape->clear(); | |||||
| if (proto.has_num() || proto.has_channels() || proto.has_height() || proto.has_width()) { | if (proto.has_num() || proto.has_channels() || proto.has_height() || proto.has_width()) { | ||||
| shape->push_back(proto.num()); | shape->push_back(proto.num()); | ||||
| shape->push_back(proto.channels()); | shape->push_back(proto.channels()); | ||||
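The ConvertShape hunk is a genuine bug fix, not style: the old code called shape->clear() before checking shape for null, so the guard could never fire on the path it was meant to protect. A minimal illustration of the corrected ordering (the real function goes on to fill the vector):

    #include <vector>

    // Sketch of the reordering only; error code is illustrative.
    int ConvertShapeSketch(std::vector<int> *shape) {
      if (shape == nullptr) {  // guard first ...
        return -1;
      }
      shape->clear();          // ... dereference second
      return 0;
    }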
| @@ -18,7 +18,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| CaffeNodeParserRegistry::CaffeNodeParserRegistry() {} | |||||
| CaffeNodeParserRegistry::CaffeNodeParserRegistry() = default; | |||||
| CaffeNodeParserRegistry::~CaffeNodeParserRegistry() { | CaffeNodeParserRegistry::~CaffeNodeParserRegistry() { | ||||
| for (auto ite : parsers) { | for (auto ite : parsers) { | ||||
| @@ -38,7 +38,7 @@ STATUS CaffePermuteParser::Parse(const caffe::LayerParameter &proto, const caffe | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::PermuteParameter permuteParam = proto.permute_param(); | |||||
| const caffe::PermuteParameter &permuteParam = proto.permute_param(); | |||||
| const int num_order_dims = permuteParam.order_size(); | const int num_order_dims = permuteParam.order_size(); | ||||
| attr->perm.resize(num_order_dims); | attr->perm.resize(num_order_dims); | ||||
| for (int i = 0; i < num_order_dims; ++i) { | for (int i = 0; i < num_order_dims; ++i) { | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffePermuteParser : public CaffeNodeParser { | class CaffePermuteParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffePermuteParser() : CaffeNodeParser("Permute") {} | CaffePermuteParser() : CaffeNodeParser("Permute") {} | ||||
| ~CaffePermuteParser() = default; | |||||
| ~CaffePermuteParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -17,9 +17,6 @@ | |||||
| #include "tools/converter/parser/caffe/caffe_pooling_parser.h" | #include "tools/converter/parser/caffe/caffe_pooling_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| const uint32_t INNERPRODUCT_WINDOW_DEFAULT_VALUE = 0; | |||||
| const uint32_t INNERPRODUCT_PAD_DEFAULT_VALUE = 0; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS CaffePoolingParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffePoolingParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| @@ -43,7 +40,7 @@ STATUS CaffePoolingParser::Parse(const caffe::LayerParameter &proto, const caffe | |||||
| attr->format = schema::Format::Format_NCHW; | attr->format = schema::Format::Format_NCHW; | ||||
| const caffe::PoolingParameter poolingParam = proto.pooling_param(); | |||||
| const caffe::PoolingParameter &poolingParam = proto.pooling_param(); | |||||
| auto status = ParsePads(poolingParam, attr.get()); | auto status = ParsePads(poolingParam, attr.get()); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | ||||
| @@ -68,15 +65,12 @@ STATUS CaffePoolingParser::Parse(const caffe::LayerParameter &proto, const caffe | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| // default roundMode RoundMode_CEIL | |||||
| attr->roundMode = schema::RoundMode_CEIL; | attr->roundMode = schema::RoundMode_CEIL; | ||||
| if (poolingParam.has_round_mode()) { | if (poolingParam.has_round_mode()) { | ||||
| if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_FLOOR) { | if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_FLOOR) { | ||||
| attr->roundMode = schema::RoundMode_FLOOR; | attr->roundMode = schema::RoundMode_FLOOR; | ||||
| } else if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_CEIL) { | } else if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_CEIL) { | ||||
| attr->roundMode = schema::RoundMode_CEIL; | attr->roundMode = schema::RoundMode_CEIL; | ||||
| } else { | |||||
| MS_ASSERT(false); | |||||
| } | } | ||||
| } | } | ||||
| attr->padMode = schema::PadMode_CAFFE; | attr->padMode = schema::PadMode_CAFFE; | ||||
| @@ -127,8 +121,8 @@ STATUS CaffePoolingParser::ParseWindows(const caffe::PoolingParameter &poolingPa | |||||
| MS_LOG(ERROR) << "With Global_pooling: true Filter size cannot specified"; | MS_LOG(ERROR) << "With Global_pooling: true Filter size cannot specified"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->windowH = INNERPRODUCT_WINDOW_DEFAULT_VALUE; | |||||
| attr->windowW = INNERPRODUCT_WINDOW_DEFAULT_VALUE; | |||||
| attr->windowH = 0; | |||||
| attr->windowW = 0; | |||||
| attr->global = true; | attr->global = true; | ||||
| } else { | } else { | ||||
| if (poolingParam.has_kernel_size() == (poolingParam.has_kernel_h() || poolingParam.has_kernel_w())) { | if (poolingParam.has_kernel_size() == (poolingParam.has_kernel_h() || poolingParam.has_kernel_w())) { | ||||
| @@ -157,7 +151,7 @@ STATUS CaffePoolingParser::ParsePoolingMode(const caffe::PoolingParameter &pooli | |||||
| } else if (poolingParam.pool() == caffe::PoolingParameter::AVE) { | } else if (poolingParam.pool() == caffe::PoolingParameter::AVE) { | ||||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | attr->poolingMode = schema::PoolMode_MEAN_POOLING; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."; | |||||
| MS_LOG(ERROR) << "MindSpore support MAX and AVE PoolingMode only."; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -26,18 +26,18 @@ namespace lite { | |||||
| class CaffePoolingParser : public CaffeNodeParser { | class CaffePoolingParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffePoolingParser() : CaffeNodeParser("pooling") {} | CaffePoolingParser() : CaffeNodeParser("pooling") {} | ||||
| ~CaffePoolingParser() = default; | |||||
| ~CaffePoolingParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| STATUS ParsePads(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||||
| static STATUS ParsePads(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||||
| STATUS ParseStrides(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||||
| static STATUS ParseStrides(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||||
| STATUS ParseWindows(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||||
| static STATUS ParseWindows(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||||
| STATUS ParsePoolingMode(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||||
| static STATUS ParsePoolingMode(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
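Making ParsePads, ParseStrides, ParseWindows and ParsePoolingMode static (and similar helpers later in the patch, such as GetAxisIndex and GetTensorDataFromOnnx) follows from one observation: none of them read or write instance state. A sketch of the contract `static` documents, with hypothetical names:

    #include <iostream>

    class PoolingParserSketch {
     public:
      // Hypothetical helper: reads no member data, so `static` both records
      // that fact and lets callers skip constructing an object.
      static int ClampPad(int pad) { return pad < 0 ? 0 : pad; }
    };

    int main() {
      std::cout << PoolingParserSketch::ClampPad(-1) << "\n";  // prints 0
      return 0;
    }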
| @@ -18,10 +18,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| static const float CAFFE_POWER_DEFAULT_POWER = 1.0; | |||||
| static const float CAFFE_POWER_DEFAULT_SCALE = 1.0; | |||||
| static const float CAFFE_POWER_DEFAULT_SHIFT = 0.0; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS CaffePowerParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffePowerParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| @@ -43,15 +39,15 @@ STATUS CaffePowerParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::PowerParameter powerParam = proto.power_param(); | |||||
| const caffe::PowerParameter &powerParam = proto.power_param(); | |||||
| if (proto.has_power_param()) { | if (proto.has_power_param()) { | ||||
| attr->power = powerParam.has_power() ? powerParam.power() : CAFFE_POWER_DEFAULT_POWER; | |||||
| attr->scale = powerParam.has_scale() ? powerParam.scale() : CAFFE_POWER_DEFAULT_SCALE; | |||||
| attr->shift = powerParam.has_shift() ? powerParam.shift() : CAFFE_POWER_DEFAULT_SHIFT; | |||||
| attr->power = powerParam.has_power() ? powerParam.power() : 1.0; | |||||
| attr->scale = powerParam.has_scale() ? powerParam.scale() : 1.0; | |||||
| attr->shift = powerParam.has_shift() ? powerParam.shift() : 0.0; | |||||
| } else { | } else { | ||||
| attr->power = CAFFE_POWER_DEFAULT_POWER; | |||||
| attr->scale = CAFFE_POWER_DEFAULT_SCALE; | |||||
| attr->shift = CAFFE_POWER_DEFAULT_SHIFT; | |||||
| attr->power = 1.0; | |||||
| attr->scale = 1.0; | |||||
| attr->shift = 0.0; | |||||
| } | } | ||||
| op->name = proto.name(); | op->name = proto.name(); | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffePowerParser : public CaffeNodeParser { | class CaffePowerParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffePowerParser() : CaffeNodeParser("power") {} | CaffePowerParser() : CaffeNodeParser("power") {} | ||||
| ~CaffePowerParser() = default; | |||||
| ~CaffePowerParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -22,6 +22,10 @@ namespace lite { | |||||
| STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | ||||
| MS_LOG(DEBUG) << "parse CaffePReluParser"; | MS_LOG(DEBUG) << "parse CaffePReluParser"; | ||||
| if (weightVec == nullptr) { | |||||
| MS_LOG(ERROR) << "weightVec is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -38,7 +42,7 @@ STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::PReLUParameter pReluParam = proto.prelu_param(); | |||||
| const caffe::PReLUParameter &pReluParam = proto.prelu_param(); | |||||
| if (pReluParam.has_channel_shared()) { | if (pReluParam.has_channel_shared()) { | ||||
| attr->channelShared = pReluParam.channel_shared(); | attr->channelShared = pReluParam.channel_shared(); | ||||
| } else { | } else { | ||||
| @@ -49,7 +53,6 @@ STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||||
| MS_LOG(ERROR) << "PRelu No blobs data in layer " << proto.name().c_str(); | MS_LOG(ERROR) << "PRelu No blobs data in layer " << proto.name().c_str(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto slope = ConvertWeight(weight.blobs(0)); | auto slope = ConvertWeight(weight.blobs(0)); | ||||
| if (slope == nullptr) { | if (slope == nullptr) { | ||||
| MS_LOG(ERROR) << "CaffePRelu convert slope for layer " << weight.name().c_str() << " failed."; | MS_LOG(ERROR) << "CaffePRelu convert slope for layer " << weight.name().c_str() << " failed."; | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffePReluParser : public CaffeNodeParser { | class CaffePReluParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffePReluParser() : CaffeNodeParser("pRelu") {} | CaffePReluParser() : CaffeNodeParser("pRelu") {} | ||||
| ~CaffePReluParser() = default; | |||||
| ~CaffePReluParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -39,7 +39,7 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::ReductionParameter reduce_param = proto.reduction_param(); | |||||
| const caffe::ReductionParameter &reduce_param = proto.reduction_param(); | |||||
| if (reduce_param.has_operation()) { | if (reduce_param.has_operation()) { | ||||
| switch (reduce_param.operation()) { | switch (reduce_param.operation()) { | ||||
| case caffe::ReductionParameter_ReductionOp_MEAN: | case caffe::ReductionParameter_ReductionOp_MEAN: | ||||
| @@ -72,6 +72,7 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||||
| } | } | ||||
| attr->reduceToEnd = true; | attr->reduceToEnd = true; | ||||
| attr->keepDims = false; | attr->keepDims = false; | ||||
| op->name = proto.name(); | op->name = proto.name(); | ||||
| op->primitive->value.type = schema::PrimitiveType_Reduce; | op->primitive->value.type = schema::PrimitiveType_Reduce; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeReduceParser : public CaffeNodeParser { | class CaffeReduceParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeReduceParser() : CaffeNodeParser("reduce") {} | CaffeReduceParser() : CaffeNodeParser("reduce") {} | ||||
| ~CaffeReduceParser() = default; | |||||
| ~CaffeReduceParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -39,8 +39,6 @@ STATUS CaffeRelu6Parser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||||
| } | } | ||||
| attr->type = schema::ActivationType_RELU6; | attr->type = schema::ActivationType_RELU6; | ||||
| // relu: negative_slope = 0, no parameter; | |||||
| // leakyrelu: negative_slope != 0; | |||||
| if (proto.has_relu_param() && proto.relu_param().has_negative_slope()) { | if (proto.has_relu_param() && proto.relu_param().has_negative_slope()) { | ||||
| float negative_slope = proto.relu_param().negative_slope(); | float negative_slope = proto.relu_param().negative_slope(); | ||||
| if (0 != negative_slope) { | if (0 != negative_slope) { | ||||
| @@ -25,7 +25,7 @@ namespace lite { | |||||
| class CaffeRelu6Parser : public CaffeNodeParser { | class CaffeRelu6Parser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeRelu6Parser() : CaffeNodeParser("relu6") {} | CaffeRelu6Parser() : CaffeNodeParser("relu6") {} | ||||
| ~CaffeRelu6Parser() = default; | |||||
| ~CaffeRelu6Parser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -40,7 +40,7 @@ STATUS CaffeReshapeParser::Parse(const caffe::LayerParameter &proto, const caffe | |||||
| attr->format = schema::Format::Format_NCHW; | attr->format = schema::Format::Format_NCHW; | ||||
| const caffe::ReshapeParameter reshapeParam = proto.reshape_param(); | |||||
| const caffe::ReshapeParameter &reshapeParam = proto.reshape_param(); | |||||
| if (!reshapeParam.has_shape()) { | if (!reshapeParam.has_shape()) { | ||||
| MS_LOG(ERROR) << "Reshape has no shape info, ret fail"; | MS_LOG(ERROR) << "Reshape has no shape info, ret fail"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeReshapeParser : public CaffeNodeParser { | class CaffeReshapeParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeReshapeParser() : CaffeNodeParser("reshape") {} | CaffeReshapeParser() : CaffeNodeParser("reshape") {} | ||||
| ~CaffeReshapeParser() = default; | |||||
| ~CaffeReshapeParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -17,14 +17,15 @@ | |||||
| #include "tools/converter/parser/caffe/caffe_scale_parser.h" | #include "tools/converter/parser/caffe/caffe_scale_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| const int32_t NCHW_DIM_C = 1; | |||||
| const int32_t DIM_DEFAULT_SIZE = 4; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | ||||
| MS_LOG(DEBUG) << "parse CaffeScaleParser"; | MS_LOG(DEBUG) << "parse CaffeScaleParser"; | ||||
| if (weightVec == nullptr) { | |||||
| MS_LOG(ERROR) << "weightVec is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -47,10 +48,10 @@ STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| const caffe::ScaleParameter scaleParam = weight.scale_param(); | |||||
| int axis = NCHW_DIM_C; | |||||
| const caffe::ScaleParameter &scaleParam = weight.scale_param(); | |||||
| int axis = 1; | |||||
| if (scaleParam.has_axis()) { | if (scaleParam.has_axis()) { | ||||
| uint32_t axis_index = NCHW_DIM_C; | |||||
| uint32_t axis_index = 1; | |||||
| if (GetAxisIndex(scaleParam.axis(), &axis_index)) { | if (GetAxisIndex(scaleParam.axis(), &axis_index)) { | ||||
| MS_LOG(ERROR) << "scale get axis failed for layer " << weight.name().c_str(); | MS_LOG(ERROR) << "scale get axis failed for layer " << weight.name().c_str(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -93,7 +94,7 @@ STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||||
| } | } | ||||
| STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) { | STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) { | ||||
| if (axis < -DIM_DEFAULT_SIZE || axis >= DIM_DEFAULT_SIZE) { | |||||
| if (axis < -4 || axis >= 4) { | |||||
| MS_LOG(ERROR) << "Scale axis value(" << axis << ") is not correct"; | MS_LOG(ERROR) << "Scale axis value(" << axis << ") is not correct"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -102,7 +103,7 @@ STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) | |||||
| MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | ||||
| } | } | ||||
| *axis_index = (axis + DIM_DEFAULT_SIZE) % DIM_DEFAULT_SIZE; | |||||
| *axis_index = (axis + 4) % 4; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -26,12 +26,12 @@ namespace lite { | |||||
| class CaffeScaleParser : public CaffeNodeParser { | class CaffeScaleParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeScaleParser() : CaffeNodeParser("scale") {} | CaffeScaleParser() : CaffeNodeParser("scale") {} | ||||
| ~CaffeScaleParser() = default; | |||||
| ~CaffeScaleParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| STATUS GetAxisIndex(const int32_t &axis, uint32_t *axis_index); | |||||
| static STATUS GetAxisIndex(const int32_t &axis, uint32_t *axis_index); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeSigmoidParser : public CaffeNodeParser { | class CaffeSigmoidParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeSigmoidParser() : CaffeNodeParser("sigmoid") {} | CaffeSigmoidParser() : CaffeNodeParser("sigmoid") {} | ||||
| ~CaffeSigmoidParser() = default; | |||||
| ~CaffeSigmoidParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -33,7 +33,6 @@ STATUS CaffeSliceParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||||
| } | } | ||||
| std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>(); | std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>(); | ||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -56,12 +55,12 @@ STATUS CaffeSliceParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||||
| attr->sizeSplits = size_splits; | attr->sizeSplits = size_splits; | ||||
| } | } | ||||
| // The axis along which to slice -- may be negative to index from the end (e.g., -1 for the last axis). | |||||
| if (slice_param.has_axis()) { | if (slice_param.has_axis()) { | ||||
| attr->splitDim = slice_param.axis(); | attr->splitDim = slice_param.axis(); | ||||
| } else if (slice_param.has_slice_dim()) { | } else if (slice_param.has_slice_dim()) { | ||||
| attr->splitDim = slice_param.slice_dim(); | attr->splitDim = slice_param.slice_dim(); | ||||
| } | } | ||||
| op->name = proto.name(); | op->name = proto.name(); | ||||
| op->primitive->value.type = schema::PrimitiveType_Split; | op->primitive->value.type = schema::PrimitiveType_Split; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeSliceParser : public CaffeNodeParser { | class CaffeSliceParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeSliceParser() : CaffeNodeParser("slice") {} | CaffeSliceParser() : CaffeNodeParser("slice") {} | ||||
| ~CaffeSliceParser() = default; | |||||
| ~CaffeSliceParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -17,8 +17,6 @@ | |||||
| #include "tools/converter/parser/caffe/caffe_softmax_parser.h" | #include "tools/converter/parser/caffe/caffe_softmax_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| static const int32_t CAFFE_SOFTMAX_DEFAULT_AXIS = 1; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS CaffeSoftmaxParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeSoftmaxParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| @@ -42,11 +40,11 @@ STATUS CaffeSoftmaxParser::Parse(const caffe::LayerParameter &proto, const caffe | |||||
| if (proto.has_softmax_param() && proto.softmax_param().has_axis()) { | if (proto.has_softmax_param() && proto.softmax_param().has_axis()) { | ||||
| if (proto.softmax_param().axis() == -1) { | if (proto.softmax_param().axis() == -1) { | ||||
| MS_LOG(ERROR) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||||
| MS_LOG(DEBUG) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||||
| } | } | ||||
| attr->axis = proto.softmax_param().axis(); | attr->axis = proto.softmax_param().axis(); | ||||
| } else { | } else { | ||||
| attr->axis = CAFFE_SOFTMAX_DEFAULT_AXIS; | |||||
| attr->axis = 1; | |||||
| } | } | ||||
| op->name = proto.name(); | op->name = proto.name(); | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeSoftmaxParser : public CaffeNodeParser { | class CaffeSoftmaxParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeSoftmaxParser() : CaffeNodeParser("softmax") {} | CaffeSoftmaxParser() : CaffeNodeParser("softmax") {} | ||||
| ~CaffeSoftmaxParser() = default; | |||||
| ~CaffeSoftmaxParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeTanhParser : public CaffeNodeParser { | class CaffeTanhParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeTanhParser() : CaffeNodeParser("tanh") {} | CaffeTanhParser() : CaffeNodeParser("tanh") {} | ||||
| ~CaffeTanhParser() = default; | |||||
| ~CaffeTanhParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -39,7 +39,7 @@ STATUS CaffeTileParser::Parse(const caffe::LayerParameter &proto, const caffe::L | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const caffe::TileParameter tile_param = proto.tile_param(); | |||||
| const caffe::TileParameter &tile_param = proto.tile_param(); | |||||
| std::vector<int> dims; | std::vector<int> dims; | ||||
| std::vector<int> multiples; | std::vector<int> multiples; | ||||
| dims.clear(); | dims.clear(); | ||||
| @@ -26,7 +26,7 @@ namespace lite { | |||||
| class CaffeTileParser : public CaffeNodeParser { | class CaffeTileParser : public CaffeNodeParser { | ||||
| public: | public: | ||||
| CaffeTileParser() : CaffeNodeParser("tile") {} | CaffeTileParser() : CaffeNodeParser("tile") {} | ||||
| ~CaffeTileParser() = default; | |||||
| ~CaffeTileParser() override = default; | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | ||||
| std::vector<schema::TensorT *> *weightVec) override; | std::vector<schema::TensorT *> *weightVec) override; | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <numeric> | #include <numeric> | ||||
| #include <functional> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -266,7 +267,6 @@ STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| // there is no Prod in onnx | |||||
| if (onnx_node.op_type() == "Sum") { | if (onnx_node.op_type() == "Sum") { | ||||
| attr->mode = schema::EltwiseMode_SUM; | attr->mode = schema::EltwiseMode_SUM; | ||||
| } else if (onnx_node.op_type() == "Max") { | } else if (onnx_node.op_type() == "Max") { | ||||
| @@ -50,13 +50,6 @@ class OnnxDivParser : public OnnxNodeParser { | |||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | ||||
| }; | }; | ||||
| class OnnxMeanParser : public OnnxNodeParser { | |||||
| public: | |||||
| OnnxMeanParser() : OnnxNodeParser("Mean") {} | |||||
| ~OnnxMeanParser() override = default; | |||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| }; | |||||
| class OnnxPowParser : public OnnxNodeParser { | class OnnxPowParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxPowParser() : OnnxNodeParser("Power") {} | OnnxPowParser() : OnnxNodeParser("Power") {} | ||||
| @@ -38,7 +38,6 @@ STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| // use channel dim as axis | |||||
| attr->axis = {1}; | attr->axis = {1}; | ||||
| op->primitive->value.type = schema::PrimitiveType_BiasAdd; | op->primitive->value.type = schema::PrimitiveType_BiasAdd; | ||||
| @@ -52,7 +52,7 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons | |||||
| attr->value.push_back(static_cast<float>(onnx_node_attr.i())); | attr->value.push_back(static_cast<float>(onnx_node_attr.i())); | ||||
| break; | break; | ||||
| case onnx::AttributeProto_AttributeType_TENSOR: { | case onnx::AttributeProto_AttributeType_TENSOR: { | ||||
| auto tensor = onnx_node_attr.t(); | |||||
| const auto &tensor = onnx_node_attr.t(); | |||||
| auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType); | auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| return ret; | return ret; | ||||
| @@ -67,7 +67,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| // set default params | |||||
| attr->strideH = 1; | attr->strideH = 1; | ||||
| attr->strideW = 1; | attr->strideW = 1; | ||||
| attr->dilateH = 1; | attr->dilateH = 1; | ||||
| @@ -75,6 +75,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| attr->group = 1; | attr->group = 1; | ||||
| attr->padMode = schema::PadMode_NOTSET; | attr->padMode = schema::PadMode_NOTSET; | ||||
| attr->format = schema::Format::Format_NCHW; | attr->format = schema::Format::Format_NCHW; | ||||
| // set opdef each attr params | // set opdef each attr params | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| if (onnx_node_attr.name() == "group") { | if (onnx_node_attr.name() == "group") { | ||||
| @@ -157,8 +158,10 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(), | auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(), | ||||
| [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); | [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); | ||||
| if (iter != (*nodeIter).attribute().end()) { | if (iter != (*nodeIter).attribute().end()) { | ||||
| MS_ASSERT(iter->ints().begin() != nullptr); | |||||
| MS_ASSERT(iter->ints().end() != nullptr); | |||||
| if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) { | |||||
| MS_LOG(ERROR) << "dims insert failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | ||||
| } | } | ||||
| attr->channelOut = dims[0]; | attr->channelOut = dims[0]; | ||||
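Replacing MS_ASSERT with a logged check and an error return, as in the hunk above (and in the deconv weight-shape hunk below), matters because asserts are typically compiled out of release builds, while malformed models show up precisely in production. A sketch of the two behaviors in one hypothetical parser step:

    #include <cstdio>

    // Hypothetical helper, illustrating the pattern only.
    int CountDims(const long *begin, const long *end) {
      // MS_ASSERT(begin != nullptr); MS_ASSERT(end != nullptr);
      //   -> compiled out of release builds, so a bad model crashes later.
      if (begin == nullptr || end == nullptr) {  // always enforced
        std::fprintf(stderr, "dims insert failed\n");
        return -1;
      }
      return static_cast<int>(end - begin);
    }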
| @@ -28,7 +28,7 @@ class OnnxConverter : public Converter { | |||||
| public: | public: | ||||
| OnnxConverter(); | OnnxConverter(); | ||||
| ~OnnxConverter() = default; | |||||
| ~OnnxConverter() override = default; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -71,14 +71,12 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| // set default params | |||||
| attr->padMode = schema::PadMode_NOTSET; | attr->padMode = schema::PadMode_NOTSET; | ||||
| attr->group = 1; | attr->group = 1; | ||||
| attr->strideW = 1; | attr->strideW = 1; | ||||
| attr->strideH = 1; | attr->strideH = 1; | ||||
| attr->dilateW = 1; | attr->dilateW = 1; | ||||
| attr->dilateH = 1; | attr->dilateH = 1; | ||||
| // set opdef each attr params | |||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| if (onnx_node_attr.name() == "group") { | if (onnx_node_attr.name() == "group") { | ||||
| attr->group = static_cast<int32_t>(onnx_node_attr.i()); | attr->group = static_cast<int32_t>(onnx_node_attr.i()); | ||||
| @@ -144,10 +142,14 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } | } | ||||
| std::vector<int> weight_shape; | std::vector<int> weight_shape; | ||||
| auto size = (*nodeIter).dims_size(); | auto size = (*nodeIter).dims_size(); | ||||
| weight_shape.reserve(size); | |||||
| for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
| weight_shape.emplace_back((*nodeIter).dims(i)); | weight_shape.emplace_back((*nodeIter).dims(i)); | ||||
| } | } | ||||
| MS_ASSERT(weight_shape.size() == 4); | |||||
| if (weight_shape.size() != 4) { | |||||
| MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->channelIn = weight_shape[0]; | attr->channelIn = weight_shape[0]; | ||||
| attr->channelOut = weight_shape[1] * attr->group; | attr->channelOut = weight_shape[1] * attr->group; | ||||
| @@ -31,7 +31,7 @@ class OnnxDeConvParser : public OnnxNodeParser { | |||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | ||||
| private: | private: | ||||
| bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op); | |||||
| static bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -48,11 +48,10 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; | MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| const int64_t *dataPtr = nullptr; | |||||
| for (const auto &attrPower : nodeIter->attribute()) { | for (const auto &attrPower : nodeIter->attribute()) { | ||||
| if (attrPower.name() == "value") { | if (attrPower.name() == "value") { | ||||
| const auto &t = attrPower.t(); | const auto &t = attrPower.t(); | ||||
| dataPtr = reinterpret_cast<const int64_t *>(t.raw_data().data()); | |||||
| auto *dataPtr = reinterpret_cast<const int64_t *>(t.raw_data().data()); | |||||
| for (int i = 0; i < t.dims(0); ++i) { | for (int i = 0; i < t.dims(0); ++i) { | ||||
| dst_shape.emplace_back(dataPtr[i]); | dst_shape.emplace_back(dataPtr[i]); | ||||
| } | } | ||||
| @@ -25,7 +25,7 @@ namespace lite { | |||||
| class OnnxLpNormParser : public OnnxNodeParser { | class OnnxLpNormParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxLpNormParser() : OnnxNodeParser("LpNorm") {} | OnnxLpNormParser() : OnnxNodeParser("LpNorm") {} | ||||
| ~OnnxLpNormParser() = default; | |||||
| ~OnnxLpNormParser() override = default; | |||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | ||||
| }; | }; | ||||
| @@ -39,7 +39,7 @@ STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| if (onnx_node_attr.name() == "direction") { | if (onnx_node_attr.name() == "direction") { | ||||
| auto direction = onnx_node_attr.s(); | |||||
| const auto &direction = onnx_node_attr.s(); | |||||
| attr->bidirection = direction == "bidirectional"; | attr->bidirection = direction == "bidirectional"; | ||||
| } | } | ||||
| } | } | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H | #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "google/protobuf/message.h" | #include "google/protobuf/message.h" | ||||
| #include "proto/onnx.pb.h" | #include "proto/onnx.pb.h" | ||||
| @@ -29,13 +30,13 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| class OnnxNodeParser { | class OnnxNodeParser { | ||||
| public: | public: | ||||
| explicit OnnxNodeParser(const std::string nodeName) : name(nodeName) {} | |||||
| explicit OnnxNodeParser(std::string nodeName) : name(std::move(nodeName)) {} | |||||
| virtual ~OnnxNodeParser() = default; | virtual ~OnnxNodeParser() = default; | ||||
| virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0; | virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0; | ||||
| STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type); | |||||
| static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type); | |||||
| static STATUS set_opset_version(int version) { | static STATUS set_opset_version(int version) { | ||||
| opset_version_ = version; | opset_version_ = version; | ||||
| @@ -44,9 +45,9 @@ class OnnxNodeParser { | |||||
| static int opset_version() { return opset_version_; } | static int opset_version() { return opset_version_; } | ||||
| protected: | protected: | ||||
| schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); | |||||
| static schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); | |||||
| void Split(const std::string &src_str, std::vector<std::string> *dst_str, const std::string &chr); | |||||
| static void Split(const std::string &src_str, std::vector<std::string> *dst_str, const std::string &chr); | |||||
| const std::string name; | const std::string name; | ||||
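Taking nodeName by value and moving it into the member (hence the new <utility> include) is the standard sink-parameter idiom: callers passing an rvalue pay one move, while the old `const std::string nodeName` by-value signature always copied. A self-contained sketch with an invented class name:

    #include <string>
    #include <utility>

    class NodeParserSketch {
     public:
      // Sink parameter: by-value plus std::move costs one move for rvalue
      // arguments instead of a full string copy.
      explicit NodeParserSketch(std::string node_name) : name(std::move(node_name)) {}
      const std::string name;
    };

    int main() {
      NodeParserSketch parser("Conv");  // the temporary is moved, not copied
      return parser.name.empty() ? 1 : 0;
    }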
| @@ -40,13 +40,6 @@ OnnxNodeParser *OnnxNodeParserRegistry::GetNodeParser(const std::string &name) { | |||||
| if (it != parsers.end()) { | if (it != parsers.end()) { | ||||
| return it->second; | return it->second; | ||||
| } | } | ||||
| /* should not support vague name, otherwise may get wrong parser. ex. PRelu and Relu | |||||
| for (auto const &i : parsers) { | |||||
| if (name.find(i.first) != std::string::npos) { | |||||
| return i.second; | |||||
| } | |||||
| } | |||||
| */ | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -15,7 +15,6 @@ | |||||
| */ | */ | ||||
| #include "tools/converter/parser/onnx/onnx_pool_parser.h" | #include "tools/converter/parser/onnx/onnx_pool_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -15,7 +15,6 @@ | |||||
| */ | */ | ||||
| #include "tools/converter/parser/onnx/onnx_relu_parser.h" | #include "tools/converter/parser/onnx/onnx_relu_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| @@ -63,7 +62,6 @@ STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | ||||
| schema::CNodeT *op) { | schema::CNodeT *op) { | ||||
| MS_LOG(DEBUG) << "onnx PReluParser"; | MS_LOG(DEBUG) << "onnx PReluParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -113,7 +111,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); | OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); | ||||
| OnnxNodeRegistrar g_onnxLeakyReluParser("LeakyRelu", new OnnxLeakeyReluParser()); | |||||
| OnnxNodeRegistrar g_onnxLeakyReluParser("LeakyRelu", new OnnxReluParser()); | |||||
| OnnxNodeRegistrar g_onnxPReluParser("PRelu", new OnnxPReluParser()); | OnnxNodeRegistrar g_onnxPReluParser("PRelu", new OnnxPReluParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,12 +30,6 @@ class OnnxReluParser : public OnnxNodeParser { | |||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | ||||
| }; | }; | ||||
| class OnnxLeakeyReluParser : public OnnxReluParser { | |||||
| public: | |||||
| OnnxLeakeyReluParser() : OnnxReluParser() {} | |||||
| ~OnnxLeakeyReluParser() override = default; | |||||
| }; | |||||
| class OnnxPReluParser : public OnnxNodeParser { | class OnnxPReluParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxPReluParser() : OnnxNodeParser("Prelu") {} | OnnxPReluParser() : OnnxNodeParser("Prelu") {} | ||||
| @@ -43,7 +43,6 @@ STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| attr->k = static_cast<int32_t>(onnx_node_attr.i()); | attr->k = static_cast<int32_t>(onnx_node_attr.i()); | ||||
| } | } | ||||
| } | } | ||||
| // attr->sorted; | |||||
| op->primitive->value.type = schema::PrimitiveType_TopK; | op->primitive->value.type = schema::PrimitiveType_TopK; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| @@ -41,13 +41,7 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx | |||||
| attr->conjugate = false; | attr->conjugate = false; | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| const auto &attribute_name = onnx_node_attr.name(); | const auto &attribute_name = onnx_node_attr.name(); | ||||
| if (attribute_name == "axes") { | |||||
| attr->perm.resize(onnx_node_attr.ints_size()); | |||||
| for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { | |||||
| attr->perm[i] = onnx_node_attr.ints(i); | |||||
| } | |||||
| } | |||||
| if (attribute_name == "perm") { | |||||
| if (attribute_name == "axes" || attribute_name == "perm") { | |||||
| attr->perm.resize(onnx_node_attr.ints_size()); | attr->perm.resize(onnx_node_attr.ints_size()); | ||||
| for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { | for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { | ||||
| attr->perm[i] = onnx_node_attr.ints(i); | attr->perm[i] = onnx_node_attr.ints(i); | ||||
| @@ -15,7 +15,6 @@ | |||||
| */ | */ | ||||
| #include "tools/converter/parser/onnx/onnx_upsample_parser.h" | #include "tools/converter/parser/onnx/onnx_upsample_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -18,7 +18,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -26,6 +25,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| MS_ASSERT(tflite_op != nullptr); | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -71,6 +73,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| } | } | ||||
| attr->alpha = tflite_attr->alpha; | attr->alpha = tflite_attr->alpha; | ||||
| attr->type = schema::ActivationType_LEAKY_RELU; | attr->type = schema::ActivationType_LEAKY_RELU; | ||||
| } else { | |||||
| MS_LOG(ERROR) << node_name << " is not supported"; |||||
| return RET_NOT_FIND_OP; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | op->primitive->value.type = schema::PrimitiveType_Activation; | ||||
| @@ -81,12 +86,12 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); | |||||
| TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); | |||||
| TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); | |||||
| TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteSwishParser()); | |||||
| TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); | |||||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); | |||||
| TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); | |||||
| TfliteNodeRegister g_tfliteReluParser("Relu", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteRelu6Parser("Relu6", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteTanhParser("Tanh", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteSwishParser("Swish", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteHardSwishParser("HardSwish", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
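The registrar rewrite above, together with the subclass deletions in the next hunk (and the earlier removal of OnnxLeakeyReluParser and OnnxMeanParser), replaces seven behaviour-free subclasses with one parser type registered under several op names; the registry key, not the C++ type, selects the behavior. A minimal sketch of that registration shape, with illustrative names:

    #include <map>
    #include <memory>
    #include <string>

    struct ParserSketch {
      virtual ~ParserSketch() = default;
    };
    struct ActivationParserSketch : ParserSketch {};

    int main() {
      std::map<std::string, std::unique_ptr<ParserSketch>> registry;
      // One concrete type registered under every activation op name.
      for (const char *name : {"Relu", "Relu6", "Tanh", "Swish", "LeakyRelu"}) {
        registry.emplace(name, std::make_unique<ActivationParserSketch>());
      }
      return registry.size() == 5 ? 0 : 1;
    }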
| @@ -34,41 +34,6 @@ class TfliteActivationParser : public TfliteNodeParser { | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| }; | }; | ||||
| class TfliteReluParser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteReluParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteRelu6Parser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteRelu6Parser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteTanhParser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteTanhParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteLogisticParser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteLogisticParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteSwishParser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteSwishParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteHardSwishParser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteHardSwishParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteLeakyReluParser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteLeakyReluParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,7 +18,6 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_addn_parser.h" | #include "tools/converter/parser/tflite/tflite_addn_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -26,6 +25,9 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| MS_LOG(DEBUG) << "parse TfliteAddNParser"; | MS_LOG(DEBUG) << "parse TfliteAddNParser"; | ||||
| MS_ASSERT(tflite_op != nullptr); | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -43,11 +45,12 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| } | } | ||||
| attr->N = tflite_subgraph->tensors.size() - 1; | attr->N = tflite_subgraph->tensors.size() - 1; | ||||
| op->primitive->value.type = schema::PrimitiveType_AddN; | op->primitive->value.type = schema::PrimitiveType_AddN; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| for (int input : tflite_op->inputs) { | |||||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
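Besides the MS_ASSERT guards, the AddN hunk rewrites the index loop over tflite_op->inputs as a range-for. A self-contained sketch of the same transformation, with AddOpInput simulated by a print:

```cpp
#include <cstdio>
#include <vector>

// Simulated stand-in for the converter's AddOpInput bookkeeping.
void AddOpInput(int tensor_idx) { std::printf("input tensor %d\n", tensor_idx); }

int main() {
  // tflite_op->inputs holds tensor indices; simulated here.
  std::vector<int> inputs = {3, 7, 11};

  // Before: index-based loop.
  for (size_t i = 0; i < inputs.size(); i++) {
    AddOpInput(inputs[i]);
  }

  // After: range-for binds each index directly; same behavior,
  // no subscript arithmetic.
  for (int input : inputs) {
    AddOpInput(input);
  }
  return 0;
}
```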
| @@ -25,6 +25,9 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | ||||
| MS_ASSERT(tflite_op != nullptr); | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -48,7 +51,12 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| // get axis attr | // get axis attr | ||||
| auto axis_idx = tflite_op->inputs[1]; | auto axis_idx = tflite_op->inputs[1]; | ||||
| auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; | |||||
| auto axis_tensor = tflite_subgraph->tensors[axis_idx].get(); | |||||
| if (axis_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "axis_tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto buffer_idx = axis_tensor->buffer; | |||||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | auto &buf_data = tflite_model->buffers[buffer_idx]; | ||||
| if (buf_data == nullptr) { | if (buf_data == nullptr) { | ||||
| MS_LOG(ERROR) << "the buf data is null"; | MS_LOG(ERROR) << "the buf data is null"; | ||||
| @@ -69,6 +77,6 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_TfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); | |||||
| TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
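The argmax change stops chaining tflite_subgraph->tensors[axis_idx]->buffer in one expression: the raw pointer is fetched with .get() and checked first, so a malformed model yields RET_NULL_PTR instead of undefined behavior. A minimal sketch of the pattern; the TensorT struct and the -1 sentinel are illustrative only:

```cpp
#include <iostream>
#include <memory>
#include <vector>

struct TensorT { int buffer = 0; };  // stand-in for tflite::TensorT

// Returns the buffer index of tensors[idx], or -1 (illustrative sentinel)
// if the slot is empty or out of range.
int GetBufferIndex(const std::vector<std::unique_ptr<TensorT>> &tensors, size_t idx) {
  if (idx >= tensors.size()) {
    return -1;
  }
  auto *axis_tensor = tensors[idx].get();  // fetch the raw pointer first ...
  if (axis_tensor == nullptr) {            // ... and check it before use
    std::cerr << "axis_tensor is null\n";
    return -1;
  }
  return axis_tensor->buffer;
}

int main() {
  std::vector<std::unique_ptr<TensorT>> tensors;
  tensors.push_back(std::make_unique<TensorT>());
  tensors.push_back(nullptr);  // simulates a missing tensor
  std::cout << GetBufferIndex(tensors, 0) << " " << GetBufferIndex(tensors, 1) << "\n";
  return 0;
}
```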
| @@ -25,6 +25,9 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| MS_LOG(DEBUG) << "parse TfliteArgminParser"; | MS_LOG(DEBUG) << "parse TfliteArgminParser"; | ||||
| MS_ASSERT(tflite_op != nullptr); | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -48,7 +51,12 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| // get axis attr | // get axis attr | ||||
| auto axis_idx = tflite_op->inputs[1]; | auto axis_idx = tflite_op->inputs[1]; | ||||
| auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; | |||||
| auto axis_tensor = tflite_subgraph->tensors[axis_idx].get(); | |||||
| if (axis_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "axis_tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto buffer_idx = axis_tensor->buffer; | |||||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | auto &buf_data = tflite_model->buffers[buffer_idx]; | ||||
| if (buf_data == nullptr) { | if (buf_data == nullptr) { | ||||
| MS_LOG(ERROR) << "the buf data is null"; | MS_LOG(ERROR) << "the buf data is null"; | ||||
| @@ -69,6 +77,6 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_TfliteArgminParser("Argmin", new TfliteArgminParser()); | |||||
| TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
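Argmin receives the same hardening as argmax. The prologue added across all of these parsers combines debug-time MS_ASSERTs on the smart-pointer arguments with a release-safe null check on op. A condensed, compilable sketch, where the status codes, macros, and struct types are stand-ins for the MindSpore originals:

```cpp
#include <cassert>
#include <iostream>
#include <memory>

enum Status { RET_OK = 0, RET_NULL_PTR = 1 };
#define MS_ASSERT(cond) assert(cond)  // stand-in for MindSpore's macro

struct OperatorT {};  // stand-in for tflite::OperatorT
struct CNodeT {};     // stand-in for schema::CNodeT

Status Parse(const std::unique_ptr<OperatorT> &tflite_op, CNodeT *op) {
  // Debug builds abort here if the converter wired up a null argument.
  MS_ASSERT(tflite_op != nullptr);
  // Release builds still refuse to dereference a null output node.
  if (op == nullptr) {
    std::cerr << "op is null\n";
    return RET_NULL_PTR;
  }
  return RET_OK;
}

int main() {
  auto tflite_op = std::make_unique<OperatorT>();
  CNodeT node;
  return Parse(tflite_op, &node) == RET_OK ? 0 : 1;
}
```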
| @@ -18,7 +18,6 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -26,6 +25,9 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| MS_ASSERT(tflite_op != nullptr); | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -165,11 +167,14 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Minimum; | op->primitive->value.type = schema::PrimitiveType_Minimum; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } else { | |||||
| MS_LOG(ERROR) << node_name << " is not supported"; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | } | ||||
| // set input | // set input | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| for (int input : tflite_op->inputs) { | |||||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -179,6 +184,9 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| MS_ASSERT(tflite_op != nullptr); | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -303,6 +311,9 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Neg; | op->primitive->value.type = schema::PrimitiveType_Neg; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } else { | |||||
| MS_LOG(ERROR) << node_name << " is not supported"; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | } | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| @@ -314,6 +325,9 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| MS_ASSERT(tflite_op != nullptr); | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -381,45 +395,48 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_LessEqual; | op->primitive->value.type = schema::PrimitiveType_LessEqual; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } else { | |||||
| MS_LOG(ERROR) << node_name << " is not supported"; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| for (int input : tflite_op->inputs) { | |||||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); | |||||
| TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteSubParser()); | |||||
| TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteMulParser()); | |||||
| TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDivParser()); | |||||
| TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteFloorDivParser()); | |||||
| TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteFloorModParser()); | |||||
| TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteRealDivParser()); | |||||
| TfliteNodeRegister g_TflitePowParser("Pow", new TflitePowParser()); | |||||
| TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteSquaredDifferenceParser()); | |||||
| TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteMaximumParser()); | |||||
| TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteMinimumParser()); | |||||
| TfliteNodeRegister g_tfliteAddParser("Add", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteMulParser("Mul", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteDivParser("Div", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tflitePowParser("Pow", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteAbsParser()); | |||||
| TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteExpParser()); | |||||
| TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSqrtParser()); | |||||
| TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteRsqrtParser()); | |||||
| TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSquareParser()); | |||||
| TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSinParser()); | |||||
| TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteCosParser()); | |||||
| TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); | |||||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); | |||||
| TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); | |||||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); | |||||
| TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteNegParser()); | |||||
| TfliteNodeRegister g_tfliteAbsParser("Abs", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteExpParser("Exp", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteSquareParser("Square", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteSinParser("Sin", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteCosParser("Cos", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteLogParser("Log", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteCeilParser("Ceil", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); | |||||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); | |||||
| TfliteNodeRegister g_tfliteGreaterEParser("Greater", new TfliteGreaterParser()); | |||||
| TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteGreaterEqualParser()); | |||||
| TfliteNodeRegister g_tfliteLessParser("Less", new TfliteLessParser()); | |||||
| TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteLessEqualParser()); | |||||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteCompareOpParser()); | |||||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteCompareOpParser()); | |||||
| TfliteNodeRegister g_tfliteGreaterParser("Greater", new TfliteCompareOpParser()); | |||||
| TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteCompareOpParser()); | |||||
| TfliteNodeRegister g_tfliteLessParser("Less", new TfliteCompareOpParser()); | |||||
| TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteCompareOpParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
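With the registrations above all pointing at the three generic parsers, each strcmp chain also gains a terminal else that returns RET_NOT_FIND_OP, so an unknown op name can no longer fall through with an unset primitive type. A compact sketch of the guarded dispatch, with the enums standing in for the schema types:

```cpp
#include <cstring>
#include <iostream>

enum Status { RET_OK = 0, RET_NOT_FIND_OP = 2 };
enum class PrimitiveType { Maximum, Minimum };

// Dispatch on the TFLite op name; unknown names are rejected explicitly
// instead of silently leaving the primitive type unset.
Status SetPrimitiveType(const char *node_name, PrimitiveType *out) {
  if (std::strcmp(node_name, "Maximum") == 0) {
    *out = PrimitiveType::Maximum;
  } else if (std::strcmp(node_name, "Minimum") == 0) {
    *out = PrimitiveType::Minimum;
  } else {
    std::cerr << node_name << " is not supported\n";
    return RET_NOT_FIND_OP;
  }
  return RET_OK;
}

int main() {
  PrimitiveType t;
  return SetPrimitiveType("Maximum", &t) == RET_OK &&
         SetPrimitiveType("Erf", &t) == RET_NOT_FIND_OP ? 0 : 1;
}
```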
| @@ -34,61 +34,6 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| }; | }; | ||||
| class TfliteAddParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteAddParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSubParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteSubParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteMulParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteMulParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteDivParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteDivParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteFloorDivParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteFloorDivParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteFloorModParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteFloorModParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSquaredDifferenceParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteSquaredDifferenceParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteRealDivParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteRealDivParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TflitePowParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TflitePowParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteMaximumParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteMaximumParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteMinimumParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteMinimumParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSingleInputOpParser : public TfliteNodeParser { | class TfliteSingleInputOpParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | ||||
| @@ -98,66 +43,6 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| }; | }; | ||||
| class TfliteAbsParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteAbsParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteExpParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteExpParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSqrtParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteSqrtParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSquareParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteSquareParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSinParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteSinParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteCosParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteCosParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteRsqrtParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteRsqrtParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteLogParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteLogParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteRoundParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteRoundParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteCeilParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteCeilParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteFloorParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteFloorParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteNegParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteNegParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteCompareOpParser : public TfliteNodeParser { | class TfliteCompareOpParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | ||||
| @@ -166,36 +51,6 @@ class TfliteCompareOpParser : public TfliteNodeParser { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| }; | }; | ||||
| class TfliteEqualParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteEqualParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteNotEqualParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteNotEqualParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteGreaterParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteGreaterParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteGreaterEqualParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteGreaterEqualParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteLessParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteLessParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteLessEqualParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteLessEqualParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,6 +27,9 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| MS_ASSERT(tflite_op != nullptr); | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -44,6 +47,9 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; | MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; | ||||
| } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { | } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; | MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; | ||||
| } else { | |||||
| MS_LOG(ERROR) << node_name << " is not supported"; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | } | ||||
| std::unique_ptr<schema::BatchToSpaceT> attr = std::make_unique<schema::BatchToSpaceT>(); | std::unique_ptr<schema::BatchToSpaceT> attr = std::make_unique<schema::BatchToSpaceT>(); | ||||
| @@ -70,6 +76,6 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| } | } | ||||
| TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); | TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); | ||||
| TfliteNodeRegister g_TfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceNDParser()); | |||||
| TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,11 +33,6 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| }; | }; | ||||
| class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | |||||
| public: | |||||
| TfliteBatchToSpaceNDParser() : TfliteBatchToSpaceParser() {} | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||