| @@ -89,7 +89,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { | |||
| } | |||
| int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| const std::shared_ptr<PrimitiveC> primitive, | |||
| const std::shared_ptr<PrimitiveC> &primitive, | |||
| const std::unique_ptr<schema::CNodeT> &dst_node) { | |||
| MS_ASSERT(meta_graph != nullptr); | |||
| MS_ASSERT(primitive != nullptr); | |||
| @@ -173,7 +173,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> & | |||
| int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *return_node) { | |||
| MS_ASSERT(nullptr != meta_graph); | |||
| MS_ASSERT(nullptr != meta_graphT); | |||
| MS_ASSERT(nullptr != return_node); | |||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | |||
| auto input_node = cnode->input(i); | |||
| @@ -191,8 +191,8 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_pt | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < return_node->inputIndex.size(); ++i) { | |||
| meta_graphT->outputIndex.push_back(return_node->inputIndex[i]); | |||
| for (unsigned int &i : return_node->inputIndex) { | |||
| meta_graphT->outputIndex.push_back(i); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -272,7 +272,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||
| return meta_graphT.release(); | |||
| } | |||
| int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode) { | |||
| int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) { | |||
| std::string input_name = input_anode->fullname_with_scope(); | |||
| auto input_cnode = utils::cast<CNodePtr>(input_anode); | |||
| @@ -336,7 +336,7 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s | |||
| return RET_OK; | |||
| } | |||
| int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anode, | |||
| int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode, | |||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *output_cnode) { | |||
| auto paramNode = input_anode->cast<ParameterPtr>(); | |||
| @@ -382,7 +382,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod | |||
| return RET_OK; | |||
| } | |||
| int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode, | |||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *output_cnode) { | |||
| auto valueNode = input_anode->cast<ValueNodePtr>(); | |||
| @@ -478,7 +478,7 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *fb_node) { | |||
| MS_ASSERT(nullptr != meta_graph); | |||
| MS_ASSERT(nullptr != meta_graphT); | |||
| MS_ASSERT(nullptr != fb_node); | |||
| if (cnode->inputs().size() <= 1) { | |||
| return RET_OK; | |||
| @@ -518,14 +518,14 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch | |||
| void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *fb_node) { | |||
| MS_ASSERT(nullptr != graph); | |||
| MS_ASSERT(nullptr != meta_graphT); | |||
| MS_ASSERT(nullptr != fb_node); | |||
| std::string cnode_name = fb_node->name; | |||
| if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) { | |||
| auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract()); | |||
| for (size_t i = 0; i < tuple->size(); i++) { | |||
| auto msTensor = new schema::TensorT(); | |||
| auto msTensor = new (std::nothrow) schema::TensorT(); | |||
| msTensor->nodeType = schema::NodeType_CNode; | |||
| fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| #ifdef SUPPORT_TRAIN | |||
| @@ -552,7 +552,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| #endif | |||
| } | |||
| } else { | |||
| auto ms_tensor = new schema::TensorT(); | |||
| auto ms_tensor = new (std::nothrow) schema::TensorT(); | |||
| ms_tensor->nodeType = schema::NodeType_CNode; | |||
| ms_tensor->dataType = TypeId::kNumberTypeFloat32; | |||
| fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ | |||
| #define MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_ | |||
| #include <map> | |||
| #include <string> | |||
| @@ -36,21 +36,22 @@ class AnfExporter { | |||
| schema::CNodeT *fb_node); | |||
| int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *fb_node); | |||
| void RemoveIfMakeTuple(const CNodePtr &cnode); | |||
| void RemoveIfDepend(const CNodePtr &cnode); | |||
| static void RemoveIfMakeTuple(const CNodePtr &cnode); | |||
| static void RemoveIfDepend(const CNodePtr &cnode); | |||
| protected: | |||
| int ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode); | |||
| int ConvertInputParameter(const std::shared_ptr<AnfNode> input_anode, | |||
| int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode); | |||
| int ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode, | |||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | |||
| int ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode, | |||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | |||
| void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT); | |||
| int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *return_node); | |||
| bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); | |||
| int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| const std::shared_ptr<PrimitiveC> primitive, const std::unique_ptr<schema::CNodeT> &dst_node); | |||
| static bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); | |||
| static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| const std::shared_ptr<PrimitiveC> &primitive, | |||
| const std::unique_ptr<schema::CNodeT> &dst_node); | |||
| private: | |||
| std::map<std::string, int> node_id_map_; | |||
| @@ -62,4 +63,4 @@ class AnfExporter { | |||
| // and clear. | |||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false); | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_ | |||
| @@ -15,8 +15,6 @@ | |||
| */ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "tools/anf_importer/anf_importer.h" | |||
| #include "schema/model_generated.h" | |||
| #include "ir/dtype.h" | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ | |||
| #define MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_ | |||
| #include <unordered_map> | |||
| #include "ir/func_graph.h" | |||
| @@ -51,4 +51,4 @@ class AnfImporter { | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_ | |||
| @@ -22,7 +22,6 @@ | |||
| #include "src/param_value_lite.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "include/errorcode.h" | |||
| #include "tools/common/tensor_util.h" | |||
| namespace mindspore::lite { | |||
| int AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||
| @@ -31,11 +30,9 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||
| for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { | |||
| auto &tensor = meta_graph_->allTensors.at(i); | |||
| MS_ASSERT(tensor != nullptr); | |||
| // converter weight and graph input into parameter node | |||
| if (tensor->nodeType != schema::NodeType::NodeType_ValueNode) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(tensor->dims() != nullptr); | |||
| auto parameter = func_graph_->add_parameter(); | |||
| std::vector<int> shape(tensor->dims.size()); | |||
| std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); | |||
| @@ -45,11 +42,12 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| MS_ASSERT(nullptr != abstract_tensor); | |||
| parameter->set_abstract(abstract_tensor); | |||
| parameter->set_name("const_" + std::to_string(i) + "_parameter"); | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| MS_ASSERT(param_value != nullptr); | |||
| MS_ASSERT(nullptr != param_value); | |||
| param_value->set_tensor_shape(shape); | |||
| param_value->set_tensor_type(type_id); | |||
| param_value->set_format(tensor->format); | |||
| @@ -123,7 +121,9 @@ abstract::AbstractTensorPtr AnfImporterFromMetaGraphT::ConvertTensorToAbstractTe | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| auto ptr = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| MS_ASSERT(nullptr != ptr); | |||
| return ptr; | |||
| } | |||
| int AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr<schema::CNodeT> &src_cnode, | |||
| @@ -175,15 +175,16 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<AnfNodePtr> op_inputs = {anf_primitive}; | |||
| for (unsigned int j : cNode->inputIndex) { | |||
| for (int j : cNode->inputIndex) { | |||
| auto node = GetNode(j); | |||
| if (nullptr == node) { | |||
| MS_LOG(ERROR) << "Can't find input node."; | |||
| return RET_ERROR; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op_inputs.push_back(node); | |||
| } | |||
| auto new_cnode = func_graph_->NewCNode(op_inputs); | |||
| MS_ASSERT(nullptr != new_cnode); | |||
| new_cnode->set_fullname_with_scope(cNode->name); | |||
| auto status = ConvertAbstract(cNode, new_cnode); | |||
| if (status != RET_OK) { | |||
| @@ -195,10 +196,8 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { | |||
| } | |||
| int AnfImporterFromMetaGraphT::AddReturnCNode() { | |||
| if (meta_graph_ == nullptr || func_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "meta_graph or func_graph is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| MS_ASSERT(nullptr != meta_graph_); | |||
| MS_ASSERT(nullptr != func_graph_); | |||
| if (meta_graph_->outputIndex.size() > 1) { | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| auto make_tuple_prim_ptr = GetMakeTuplePrim(); | |||
| @@ -229,6 +228,7 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() { | |||
| op_inputs.emplace_back(value_node); | |||
| op_inputs.emplace_back(make_tuple_cnode); | |||
| auto cnode = func_graph_->NewCNode(op_inputs); | |||
| MS_ASSERT(nullptr != cnode); | |||
| cnode->set_fullname_with_scope("return"); | |||
| func_graph_->set_return(cnode); | |||
| } else { | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | |||
| #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | |||
| #include <utility> | |||
| #include <memory> | |||
| @@ -40,7 +40,9 @@ class AnfImporterFromMetaGraphT : public AnfImporter { | |||
| int ConverterCNode() override; | |||
| ValueNodePtr ConvertPrimitive(const std::unique_ptr<schema::CNodeT> &cNode); | |||
| abstract::AbstractTensorPtr ConvertTensorToAbstractTensor(const std::unique_ptr<schema::TensorT> &tensor); | |||
| static abstract::AbstractTensorPtr ConvertTensorToAbstractTensor(const std::unique_ptr<schema::TensorT> &tensor); | |||
| int ConvertAbstract(const std::unique_ptr<schema::CNodeT> &src_cnode, const CNodePtr &dst_cnode); | |||
| int AddReturnCNode() override; | |||
| @@ -51,4 +53,4 @@ class AnfImporterFromMetaGraphT : public AnfImporter { | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | |||
| @@ -239,7 +239,7 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node | |||
| node->set_abstract(abstract_tensor); | |||
| if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { | |||
| auto *tensor_info = new Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); | |||
| auto *tensor_info = new (std::nothrow) Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); | |||
| if (tensor_info == nullptr) { | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| @@ -345,7 +345,6 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor | |||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; | |||
| return {}; | |||
| } | |||
| return {}; | |||
| } | |||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| @@ -871,7 +870,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||
| } | |||
| onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { | |||
| auto onnx_model = new onnx::ModelProto; | |||
| auto onnx_model = new (std::nothrow) onnx::ModelProto; | |||
| if (RET_OK != ValidateFileStr(model_path, ".mindir")) { | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID); | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | |||
| #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | |||
| #include <map> | |||
| #include <string> | |||
| @@ -81,4 +81,4 @@ class AnfImporterFromProtobuf : public AnfImporter { | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | |||
| @@ -24,7 +24,6 @@ Option<std::string> FlagParser::ParseFlags(int argc, const char *const *argv, bo | |||
| bool supportDuplicate) { | |||
| MS_ASSERT(argv != nullptr); | |||
| const int FLAG_PREFIX_LEN = 2; | |||
| // Get binary name | |||
| binName = GetFileName(argv[0]); | |||
| std::multimap<std::string, Option<std::string>> keyValues; | |||
| @@ -45,9 +44,7 @@ Option<std::string> FlagParser::ParseFlags(int argc, const char *const *argv, bo | |||
| Option<std::string> value = Option<std::string>(None()); | |||
| size_t pos = flagItem.find_first_of('='); | |||
| if (pos == std::string::npos && flagItem.find("--no-") != std::string::npos) { | |||
| key = flagItem.substr(FLAG_PREFIX_LEN); | |||
| } else if (pos == std::string::npos) { | |||
| if (pos == std::string::npos) { | |||
| key = flagItem.substr(FLAG_PREFIX_LEN); | |||
| } else { | |||
| key = flagItem.substr(FLAG_PREFIX_LEN, pos - FLAG_PREFIX_LEN); | |||
| @@ -81,10 +78,10 @@ bool FlagParser::GetRealFlagName(std::string *flagName, const std::string &oriFl | |||
| // Inner parse function | |||
| Option<std::string> FlagParser::InnerParseFlags(std::multimap<std::string, Option<std::string>> *keyValues) { | |||
| MS_ASSERT(keyValues != nullptr); | |||
| for (auto it = keyValues->begin(); it != keyValues->end(); ++it) { | |||
| for (auto &keyValue : *keyValues) { | |||
| std::string flagName; | |||
| bool opaque = GetRealFlagName(&flagName, (*it).first); | |||
| Option<std::string> flagValue = (*it).second; | |||
| bool opaque = GetRealFlagName(&flagName, keyValue.first); | |||
| Option<std::string> flagValue = keyValue.second; | |||
| auto item = flags.find(flagName); | |||
| if (item == flags.end()) { | |||
| @@ -133,7 +130,7 @@ Option<std::string> FlagParser::InnerParseFlags(std::multimap<std::string, Optio | |||
| return Option<std::string>(None()); | |||
| } | |||
| void Replaceall(std::string *str, const std::string &oldValue, const std::string &newValue) { | |||
| void ReplaceAll(std::string *str, const std::string &oldValue, const std::string &newValue) { | |||
| if (str == nullptr) { | |||
| MS_LOG(ERROR) << "Input str is nullptr"; | |||
| return; | |||
| @@ -153,9 +150,9 @@ std::string FlagParser::Usage(const Option<std::string> &usgMsg) const { | |||
| std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : ""; | |||
| // usage of bin name | |||
| usageString += usageMsg.IsNone() ? "\nusage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; | |||
| // help line of help message, usageLine:message of parametors | |||
| std::string helpLine = ""; | |||
| std::string usageLine = ""; | |||
| // help line of help message, usageLine:message of parameters | |||
| std::string helpLine; | |||
| std::string usageLine; | |||
| uint32_t i = 0; | |||
| for (auto flag = flags.begin(); flag != flags.end(); flag++) { | |||
| std::string flagName = flag->second.flagName; | |||
| @@ -165,7 +162,7 @@ std::string FlagParser::Usage(const Option<std::string> &usgMsg) const { | |||
| if (++i <= flags.size()) { | |||
| // add parameter help message of each line | |||
| thisLine += " " + helpInfo; | |||
| Replaceall(&helpInfo, "\n\r", "\n"); | |||
| ReplaceAll(&helpInfo, "\n\r", "\n"); | |||
| usageLine += thisLine + "\n"; | |||
| } else { | |||
| // breif help message | |||
| @@ -14,21 +14,18 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef PREDICT_COMMON_FLAG_PARSER_H_ | |||
| #define PREDICT_COMMON_FLAG_PARSER_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_FLAG_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_FLAG_PARSER_H | |||
| #include <functional> | |||
| #include <map> | |||
| #include <utility> | |||
| #include <string> | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/option.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| struct FlagInfo; | |||
| struct Nothing {}; | |||
| class FlagParser { | |||
| @@ -44,6 +41,7 @@ class FlagParser { | |||
| template <typename Flags, typename T1, typename T2> | |||
| void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2); | |||
| template <typename Flags, typename T1, typename T2> | |||
| void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2); | |||
| @@ -94,7 +92,7 @@ class FlagParser { | |||
| Option<std::string> InnerParseFlags(std::multimap<std::string, Option<std::string>> *values); | |||
| bool GetRealFlagName(std::string *flagName, const std::string &oriFlagName); | |||
| static bool GetRealFlagName(std::string *flagName, const std::string &oriFlagName); | |||
| std::map<std::string, FlagInfo> flags; | |||
| }; | |||
| @@ -181,7 +179,7 @@ void FlagParser::AddFlag(T1 *t1, const std::string &flagName, const std::string | |||
| FlagInfo flagItem; | |||
| // flagItem is as a output parameter | |||
| // flagItem is as an output parameter | |||
| ConstructFlag(t1, flagName, helpInfo, flagItem); | |||
| flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option<Nothing> { | |||
| if (base != nullptr) { | |||
| @@ -301,4 +299,4 @@ void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // PREDICT_COMMON_FLAG_PARSER_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_FLAG_PARSER_H | |||
| @@ -15,8 +15,7 @@ | |||
| */ | |||
| #include "tools/common/graph_util.h" | |||
| #include <stdlib.h> | |||
| #include <time.h> | |||
| #include <ctime> | |||
| #include <utility> | |||
| #include <set> | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -29,7 +28,10 @@ namespace mindspore { | |||
| namespace lite { | |||
| OpDefCopyer GetSimpleOpCopyer() { | |||
| return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> { | |||
| std::unique_ptr<CNodeT> newCNode(new CNodeT); | |||
| std::unique_ptr<CNodeT> newCNode = std::make_unique<CNodeT>(); | |||
| if (newCNode == nullptr) { | |||
| return nullptr; | |||
| } | |||
| newCNode->name = inCNode->name; | |||
| newCNode->quantType = inCNode->quantType; | |||
| @@ -163,8 +165,6 @@ STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) { | |||
| } | |||
| } | |||
| // whether need to remove weightInputTensores | |||
| // remove all node's outputTensors | |||
| RemoveTensor(graphT, outputTensorIdxes); | |||
| node->inputIndex.clear(); | |||
| node->outputIndex.clear(); | |||
| @@ -183,8 +183,11 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool remove | |||
| MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| CNodeT *node = graphT->nodes.at(nodeIdx).get(); | |||
| if (node == nullptr) { | |||
| MS_LOG(ERROR) << "node is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto inputTensorIdxes = node->inputIndex; | |||
| auto outputTensorIdxes = node->outputIndex; | |||
| auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx); | |||
| @@ -244,6 +247,7 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTe | |||
| size_t nodeIdx = 0; | |||
| for (size_t i = 0; i < graphT->nodes.size(); i++) { | |||
| auto &inNode = graphT->nodes.at(i); | |||
| MS_ASSERT(inNode != nullptr); | |||
| if (inNode->name == node->name) { | |||
| isSubNode = true; | |||
| nodeIdx = i; | |||
| @@ -259,6 +263,7 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTe | |||
| } | |||
| STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) { | |||
| uint32_t deleteIdx = *iter; | |||
| if (!forceDelete) { | |||
| @@ -297,6 +302,7 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe | |||
| } | |||
| STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) { | |||
| MS_ASSERT(node != nullptr); | |||
| for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) { | |||
| if (*inIdxIt == deleteIdx) { | |||
| inIdxIt = node->inputIndex.erase(inIdxIt); | |||
| @@ -330,6 +336,7 @@ STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ | |||
| graphT->allTensors.emplace_back(std::move(tensor)); | |||
| uint32_t newTensorIdx = graphT->allTensors.size() - 1; | |||
| auto node = graphT->nodes.at(nodeIdx).get(); | |||
| MS_ASSERT(node != nullptr); | |||
| if (place == kBefore) { | |||
| node->inputIndex.emplace_back(newTensorIdx); | |||
| } else { | |||
| @@ -340,11 +347,13 @@ STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ | |||
| STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx, | |||
| std::unique_ptr<TensorT> tensor) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| if (nodeIdx >= graphT->nodes.size()) { | |||
| MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto node = graphT->nodes.at(nodeIdx).get(); | |||
| MS_ASSERT(node != nullptr); | |||
| if (inTensorIdx >= graphT->allTensors.size()) { | |||
| MS_LOG(ERROR) << "inTensorIdx out of range: " << nodeIdx; | |||
| return RET_PARAM_INVALID; | |||
| @@ -358,7 +367,9 @@ STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_ | |||
| } | |||
| NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex, | |||
| std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) { | |||
| std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(errorCode != nullptr); | |||
| if (existNodeIdx >= graphT->nodes.size()) { | |||
| MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx; | |||
| return graphT->nodes.end(); | |||
| @@ -370,7 +381,9 @@ NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPla | |||
| } | |||
| NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx, | |||
| std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) { | |||
| std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(errorCode != nullptr); | |||
| if (place == kBefore) { | |||
| return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer); | |||
| } else if (place == kAfter) { | |||
| @@ -382,7 +395,9 @@ NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPl | |||
| } | |||
| NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx, | |||
| std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) { | |||
| std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, const OpDefCopyer &opDefCopyer) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(errorCode != nullptr); | |||
| auto &existNode = *existNodeIter; | |||
| MS_ASSERT(existNode != nullptr); | |||
| MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx); | |||
| @@ -390,7 +405,7 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si | |||
| auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx); | |||
| MS_ASSERT(graphT->allTensors.size() > preTensorIdx); | |||
| auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode.get()), inputIndexIdx); | |||
| auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode), inputIndexIdx); | |||
| if (preNodeIdxes.empty()) { | |||
| auto &preTensor = graphT->allTensors.at(preTensorIdx); | |||
| MS_ASSERT(preTensor != nullptr); | |||
| @@ -402,9 +417,12 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si | |||
| } | |||
| preTensor->refCount = 0; | |||
| preTensor->data.clear(); | |||
| MS_ASSERT(toAddNodeIn->primitive != nullptr); | |||
| if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { | |||
| preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; | |||
| toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; | |||
| auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast(); | |||
| MS_ASSERT(prim != nullptr); | |||
| preTensor->dataType = prim->srcT; | |||
| toAddTensor->dataType = prim->dstT; | |||
| } | |||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | |||
| size_t toAddTensorIdx = graphT->allTensors.size() - 1; | |||
| @@ -438,9 +456,12 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si | |||
| MS_LOG(ERROR) << "Copy TensorT failed"; | |||
| return graphT->nodes.end(); | |||
| } | |||
| MS_ASSERT(toAddNodeIn->primitive != nullptr); | |||
| if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { | |||
| preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; | |||
| toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; | |||
| auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast(); | |||
| MS_ASSERT(prim != nullptr); | |||
| preTensor->dataType = prim->srcT; | |||
| toAddTensor->dataType = prim->dstT; | |||
| } | |||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | |||
| size_t toAddTensorIdx = graphT->allTensors.size() - 1; | |||
| @@ -473,7 +494,10 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si | |||
| } | |||
| NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx, | |||
| std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) { | |||
| std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, | |||
| const OpDefCopyer &opDefCopyer) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(errorCode != nullptr); | |||
| auto &existNode = *existNodeIter; | |||
| MS_ASSERT(existNode != nullptr); | |||
| MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx); | |||
| @@ -481,7 +505,7 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz | |||
| auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx); | |||
| MS_ASSERT(graphT->allTensors.size() > postTensorIdx); | |||
| auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode.get()), outputIndexIdx); | |||
| auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode), outputIndexIdx); | |||
| if (postNodeIdxes.empty()) { | |||
| auto &postTensor = graphT->allTensors.at(postTensorIdx); | |||
| MS_ASSERT(postTensor != nullptr); | |||
| @@ -491,9 +515,12 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz | |||
| *errorCode = RET_NULL_PTR; | |||
| return graphT->nodes.end(); | |||
| } | |||
| MS_ASSERT(toAddNodeIn->primitive != nullptr); | |||
| if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { | |||
| postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; | |||
| toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; | |||
| auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast(); | |||
| MS_ASSERT(prim != nullptr); | |||
| postTensor->dataType = prim->srcT; | |||
| toAddTensor->dataType = prim->dstT; | |||
| } | |||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | |||
| size_t toAddTensorIdx = graphT->allTensors.size() - 1; | |||
| @@ -554,9 +581,12 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz | |||
| *errorCode = RET_NULL_PTR; | |||
| return graphT->nodes.end(); | |||
| } | |||
| MS_ASSERT(toAddNodeIn->primitive != nullptr); | |||
| if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { | |||
| postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; | |||
| toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; | |||
| auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast(); | |||
| MS_ASSERT(prim != nullptr); | |||
| postTensor->dataType = prim->srcT; | |||
| toAddTensor->dataType = prim->dstT; | |||
| } | |||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | |||
| size_t toAddTensorIdx = graphT->allTensors.size() - 1; | |||
| @@ -589,13 +619,9 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz | |||
| return existNodeIter; | |||
| } | |||
| STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) { | |||
| if (modelFile.size() > fileType.size()) { | |||
| if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) { | |||
| return RET_OK; | |||
| } else { | |||
| return RET_ERROR; | |||
| } | |||
| STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType) { | |||
| if (modelFile.size() > fileType.size() && modelFile.substr(modelFile.size() - fileType.size()) == fileType) { | |||
| return RET_OK; | |||
| } else { | |||
| return RET_ERROR; | |||
| } | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_GRAPH_UTIL_H | |||
| #define MINDSPORE_PREDICT_GRAPH_UTIL_H | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H | |||
| #include <cstdlib> | |||
| #include <unordered_map> | |||
| @@ -23,7 +23,6 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/common/graph_util.h" | |||
| @@ -73,19 +72,19 @@ STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_ | |||
| NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex, | |||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, | |||
| OpDefCopyer opDefCopyer = GetSimpleOpCopyer()); | |||
| const OpDefCopyer &opDefCopyer = GetSimpleOpCopyer()); | |||
| NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx, | |||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, | |||
| OpDefCopyer opDefCopyer = GetSimpleOpCopyer()); | |||
| const OpDefCopyer &opDefCopyer = GetSimpleOpCopyer()); | |||
| NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx, | |||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer); | |||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer); | |||
| NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx, | |||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer); | |||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer); | |||
| STATUS ValidateFileStr(const std::string &modelFile, std::string fileType); | |||
| STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType); | |||
| void TransformAttrByAxes(int *origin_attr, int *axes, int element_size); | |||
| @@ -97,4 +96,4 @@ std::string GetModelName(const std::string &modelFile); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_GRAPH_UTIL_H | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H | |||
| @@ -160,6 +160,7 @@ std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; } | |||
| STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims, | |||
| mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) { | |||
| MS_ASSERT(nullptr != dst_dims); | |||
| if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) { | |||
| MS_LOG(ERROR) << "Convert format , src size " << src_dims.size() | |||
| << " <3 or src format is equal to dst format,not need convert"; | |||
| @@ -189,7 +190,7 @@ STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::v | |||
| return RET_ERROR; | |||
| } | |||
| if (nchw_dim.size() == 0) { | |||
| if (nchw_dim.empty()) { | |||
| MS_LOG(ERROR) << "Param nchw_dim is empty!"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -215,6 +216,10 @@ STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::v | |||
| STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, | |||
| int32_t *filterH, int32_t *filterW) { | |||
| if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) { | |||
| MS_LOG(ERROR) << "null input"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| MS_ASSERT(oriDims.size() == 4); | |||
| if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) { | |||
| *filterK = oriDims.at(KCHW_K); | |||
| @@ -282,6 +287,7 @@ STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filt | |||
| STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<int32_t> oriDims = tensor->dims; | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_NODE_UTIL_H | |||
| #define MINDSPORE_PREDICT_NODE_UTIL_H | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H | |||
| #include <memory> | |||
| #include <vector> | |||
| @@ -60,13 +60,6 @@ class NodeUtils { | |||
| public: | |||
| static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format, | |||
| std::vector<int32_t> *dst_dims); | |||
| static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin, | |||
| int64_t out_dim, int64_t stride); | |||
| static STATUS SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int32_t> &input_dims, | |||
| std::vector<int32_t> &begin, std::vector<int32_t> &output_dims, | |||
| schema::TensorT *output, std::vector<int32_t> &stride); | |||
| }; | |||
| enum kTransFilterType { | |||
| @@ -133,7 +126,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in | |||
| if (type == kCHWK2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kCHWK2KHWC) { | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } | |||
| @@ -334,4 +327,4 @@ static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) | |||
| STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_NODE_UTIL_H | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef PREDICT_COMMON_OPTION_H_ | |||
| #define PREDICT_COMMON_OPTION_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_OPTION_H | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_OPTION_H | |||
| #include <type_traits> | |||
| #include <utility> | |||
| @@ -56,7 +56,7 @@ class Option { | |||
| } | |||
| } | |||
| virtual ~Option() {} | |||
| virtual ~Option() = default; | |||
| bool IsNone() const { return state == NONE; } | |||
| @@ -116,4 +116,4 @@ class Option { | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // PREDICT_COMMON_OPTION_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_OPTION_H | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_PROTOBUF_UTILS_H | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_PROTOBUF_UTILS_H | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -35,4 +35,4 @@ STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *mess | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_PROTOBUF_UTILS_H | |||
| @@ -50,7 +50,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath | |||
| } | |||
| schema::MetaGraphT *Storage::Load(const std::string &inputPath) { | |||
| size_t size; | |||
| size_t size = 0; | |||
| auto buf = ReadFile(inputPath.c_str(), &size); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "the file buffer is nullptr"; | |||
| @@ -58,7 +58,7 @@ schema::MetaGraphT *Storage::Load(const std::string &inputPath) { | |||
| } | |||
| flatbuffers::Verifier verify((const uint8_t *)buf, size); | |||
| if (false == schema::VerifyMetaGraphBuffer(verify)) { | |||
| if (!schema::VerifyMetaGraphBuffer(verify)) { | |||
| MS_LOG(ERROR) << "the buffer is invalid and fail to create meta graph"; | |||
| return nullptr; | |||
| } | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef PREDICT_COMMON_STORAGE_H_ | |||
| #define PREDICT_COMMON_STORAGE_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_STORAGE_H | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_STORAGE_H | |||
| #include <fstream> | |||
| #include <string> | |||
| @@ -27,11 +27,11 @@ namespace mindspore { | |||
| namespace lite { | |||
| class Storage { | |||
| public: | |||
| int Save(const schema::MetaGraphT &graph, const std::string &outputPath); | |||
| static int Save(const schema::MetaGraphT &graph, const std::string &outputPath); | |||
| schema::MetaGraphT *Load(const std::string &inputPath); | |||
| static schema::MetaGraphT *Load(const std::string &inputPath); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // PREDICT_COMMON_STORAGE_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_STORAGE_H | |||
| @@ -14,7 +14,6 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <cfloat> | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/common/graph_util.h" | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_TENSOR_UTIL_H | |||
| #define MINDSPORE_PREDICT_TENSOR_UTIL_H | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_TENSOR_UTIL_H | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_TENSOR_UTIL_H | |||
| #include <cmath> | |||
| #include <unordered_map> | |||
| @@ -58,13 +58,11 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem | |||
| std::unique_ptr<schema::QuantParamT> CopyQuantParamArrayT( | |||
| const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray); | |||
| using MSGraphDefTPtr = std::shared_ptr<schema::MetaGraphT>; | |||
| enum Category { CONST = 0, GRAPH_INPUT = 1, OP_OUTPUT = 2, TF_CONST = 3 }; | |||
| class TensorCache { | |||
| public: | |||
| TensorCache() {} | |||
| TensorCache() = default; | |||
| ~TensorCache() { tensors.clear(); } | |||
| @@ -97,12 +95,12 @@ class TensorCache { | |||
| return -1; | |||
| } | |||
| void UpdateTensorIndex(const std::string &name, int index) { | |||
| void UpdateTensorIndex(const std::string &name, int idx) { | |||
| auto iter = tensorIndex.find(name); | |||
| if (iter != tensorIndex.end()) { | |||
| tensorIndex[name] = index; | |||
| tensorIndex[name] = idx; | |||
| } else { | |||
| tensorIndex.insert(make_pair(name, index)); | |||
| tensorIndex.insert(make_pair(name, idx)); | |||
| } | |||
| } | |||
| @@ -120,4 +118,4 @@ class TensorCache { | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_TENSOR_UTIL_H | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_TENSOR_UTIL_H | |||
| @@ -38,17 +38,16 @@ STATUS CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||
| return RET_NULL_PTR; | |||
| } | |||
| // set default params | |||
| attr->outMaxValue = false; | |||
| attr->topK = 1; | |||
| const caffe::ArgMaxParameter argmaxParam = proto.argmax_param(); | |||
| const caffe::ArgMaxParameter &argmaxParam = proto.argmax_param(); | |||
| if (argmaxParam.has_out_max_val()) { | |||
| attr->outMaxValue = argmaxParam.out_max_val(); | |||
| } | |||
| if (argmaxParam.has_top_k()) { | |||
| attr->topK = argmaxParam.top_k(); | |||
| } | |||
| int32_t axisType; | |||
| int32_t axisType = 0; | |||
| int32_t axis = 0; | |||
| if (!argmaxParam.has_axis()) { | |||
| axisType = 2; | |||
| @@ -26,7 +26,8 @@ namespace lite { | |||
| class CaffeArgMaxParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeArgMaxParser() : CaffeNodeParser("argmax") {} | |||
| ~CaffeArgMaxParser() = default; | |||
| ~CaffeArgMaxParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| }; | |||
| @@ -19,12 +19,6 @@ | |||
| #include <memory> | |||
| #include "tools/common/tensor_util.h" | |||
| #define CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT 0.00001 | |||
| #define CAFFE_BATCH_NORM_ESP_DEFAULT_DIFF_FLOAT 0.000000001 | |||
| static const int CAFFE_BATCHNORMAL_BOTTOM_SIZE = 1; | |||
| static const int CAFFE_BATCHNORMAL_TOP_SIZE = 1; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| using STATUS = int; | |||
| @@ -32,6 +26,10 @@ using STATUS = int; | |||
| STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||
| MS_LOG(DEBUG) << "parse CaffeBatchNormParser"; | |||
| if (weightVec == nullptr) { | |||
| MS_LOG(ERROR) << "weightVec is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -48,43 +46,38 @@ STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caf | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::BatchNormParameter batchNormParam = proto.batch_norm_param(); | |||
| // check bottom size | |||
| if (proto.bottom_size() != CAFFE_BATCHNORMAL_BOTTOM_SIZE) { | |||
| MS_LOG(ERROR) << "Layer " << proto.name().c_str() << "bottom numbers is error, it must be " | |||
| << CAFFE_BATCHNORMAL_BOTTOM_SIZE << "but is " << proto.bottom_size(); | |||
| const caffe::BatchNormParameter &batchNormParam = proto.batch_norm_param(); | |||
| if (proto.bottom_size() != 1) { | |||
| MS_LOG(ERROR) << "Layer " << proto.name().c_str() << "bottom numbers is error, it must be 1, but is " | |||
| << proto.bottom_size(); | |||
| return RET_ERROR; | |||
| } | |||
| // check top size | |||
| if (proto.top_size() != CAFFE_BATCHNORMAL_TOP_SIZE) { | |||
| MS_LOG(ERROR) << "Layer " << proto.name().c_str() << "top numbers is error, it must be " | |||
| << CAFFE_BATCHNORMAL_TOP_SIZE << "but is " << proto.top_size(); | |||
| if (proto.top_size() != 1) { | |||
| MS_LOG(ERROR) << "Layer " << proto.name().c_str() << "top numbers is error, it must be 1, but is " | |||
| << proto.top_size(); | |||
| return RET_ERROR; | |||
| } | |||
| if (batchNormParam.has_eps()) { | |||
| if (fabs(CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT - batchNormParam.eps()) < CAFFE_BATCH_NORM_ESP_DEFAULT_DIFF_FLOAT) { | |||
| attr->epsilon = CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT; | |||
| if (std::fabs(1e-5 - batchNormParam.eps()) < 1e-9) { | |||
| attr->epsilon = 1e-5; | |||
| } else { | |||
| auto tmpAuto = batchNormParam.eps(); | |||
| attr->epsilon = tmpAuto; | |||
| } | |||
| } else { | |||
| attr->epsilon = CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT; | |||
| attr->epsilon = 1e-5; | |||
| } | |||
| const float blob2Data = | |||
| (weight.blobs(2).double_data_size() > 0) ? weight.blobs(2).double_data(0) : weight.blobs(2).data(0); | |||
| const float scaleFactor = blob2Data == 0 ? 0 : 1 / blob2Data; | |||
| // parse weight gamma | |||
| auto gamma = ConvertWeight(weight.blobs(0)); | |||
| if (gamma == nullptr) { | |||
| MS_LOG(ERROR) << "Convert blobs(0) for layer " << weight.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto estimatedMean = reinterpret_cast<float *>(gamma->data.data()); | |||
| auto estimatedMeanShapeSize = GetShapeSize(*gamma); | |||
| for (size_t i = 0; i < estimatedMeanShapeSize; i++) { | |||
| @@ -93,13 +86,11 @@ STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caf | |||
| estimatedMean = nullptr; | |||
| weightVec->push_back(gamma); | |||
| // parse weight beta | |||
| auto beta = ConvertWeight(weight.blobs(1)); | |||
| if (beta == nullptr) { | |||
| MS_LOG(ERROR) << "Convert blobs(1) for layer " << weight.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto estimatedVariance = reinterpret_cast<float *>(beta->data.data()); | |||
| size_t estimatedVarianceShapeSize = GetShapeSize(*beta); | |||
| for (size_t i = 0; i < estimatedVarianceShapeSize; i++) { | |||
| @@ -26,6 +26,7 @@ namespace lite { | |||
| class CaffeBatchNormParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeBatchNormParser() : CaffeNodeParser("batchnorm") {} | |||
| ~CaffeBatchNormParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -17,8 +17,6 @@ | |||
| #include "tools/converter/parser/caffe/caffe_concat_parser.h" | |||
| #include <memory> | |||
| const int32_t CONCAT_DEFAULT_AXIS = 1; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| @@ -40,7 +38,7 @@ STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::ConcatParameter concatParam = proto.concat_param(); | |||
| const caffe::ConcatParameter &concatParam = proto.concat_param(); | |||
| if (concatParam.has_axis() && concatParam.has_concat_dim()) { | |||
| MS_LOG(ERROR) << "Concat param in caffe have concat_dim and axis simultaneously, return fail"; | |||
| return RET_ERROR; | |||
| @@ -48,19 +46,19 @@ STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||
| if (concatParam.has_concat_dim()) { | |||
| MS_LOG(DEBUG) << "Concat dim , set axis: " << concatParam.concat_dim(); | |||
| int32_t concat_dim_value = (int32_t)concatParam.concat_dim(); | |||
| auto concat_dim_value = (int32_t)concatParam.concat_dim(); | |||
| if (concat_dim_value < 0) { | |||
| MS_LOG(ERROR) << "concat_dim value in model is smaller than 0:" << concat_dim_value; | |||
| return RET_ERROR; | |||
| } | |||
| attr->axis = concat_dim_value; | |||
| } else if (concatParam.has_axis()) { | |||
| MS_LOG(DEBUG) << "axis , set axis: " << concatParam.axis(); | |||
| int32_t tmpInt = (int32_t)concatParam.axis(); | |||
| MS_LOG(DEBUG) << "set axis: " << concatParam.axis(); | |||
| auto tmpInt = (int32_t)concatParam.axis(); | |||
| attr->axis = tmpInt; | |||
| } else { | |||
| MS_LOG(DEBUG) << "default , set axis: " << CONCAT_DEFAULT_AXIS; | |||
| attr->axis = CONCAT_DEFAULT_AXIS; | |||
| MS_LOG(DEBUG) << "by default, set axis = 1"; | |||
| attr->axis = 1; | |||
| } | |||
| attr->n = proto.bottom_size(); | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeConcatParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeConcatParser() : CaffeNodeParser("concat") {} | |||
| ~CaffeConcatParser() = default; | |||
| ~CaffeConcatParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -17,13 +17,6 @@ | |||
| #include "tools/converter/parser/caffe/caffe_conv_base_parser.h" | |||
| #include <algorithm> | |||
| const uint32_t PAD_DEFAULT_VALUE = 0; | |||
| const uint32_t STRIDE_DEFAULT_VALUE = 1; | |||
| const uint32_t DILATION_DEFAULT_VALUE = 1; | |||
| const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; | |||
| const uint32_t DEFAULT_CONV_GROUP = 1; | |||
| static const int CAFFE_CONV_BIAS_DIM_NUM = 1; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeConvBaseParser::ParsePads(const caffe::ConvolutionParameter &convParam, std::vector<int64_t> *pad) { | |||
| @@ -40,15 +33,15 @@ STATUS CaffeConvBaseParser::ParsePads(const caffe::ConvolutionParameter &convPar | |||
| } | |||
| if (!convParam.has_pad_h()) { | |||
| (*pad)[0] = PAD_DEFAULT_VALUE; | |||
| (*pad)[1] = PAD_DEFAULT_VALUE; | |||
| (*pad)[0] = 0; | |||
| (*pad)[1] = 0; | |||
| (*pad)[2] = convParam.pad_w(); | |||
| (*pad)[3] = convParam.pad_w(); | |||
| } else if (!convParam.has_pad_w()) { | |||
| (*pad)[0] = convParam.pad_h(); | |||
| (*pad)[1] = convParam.pad_h(); | |||
| (*pad)[2] = PAD_DEFAULT_VALUE; | |||
| (*pad)[3] = PAD_DEFAULT_VALUE; | |||
| (*pad)[2] = 0; | |||
| (*pad)[3] = 0; | |||
| } else { | |||
| (*pad)[0] = convParam.pad_h(); | |||
| (*pad)[1] = convParam.pad_h(); | |||
| @@ -56,15 +49,14 @@ STATUS CaffeConvBaseParser::ParsePads(const caffe::ConvolutionParameter &convPar | |||
| (*pad)[3] = convParam.pad_w(); | |||
| } | |||
| } else { | |||
| // default 2D | |||
| const int num_pad_dims = convParam.pad_size(); | |||
| int num_spatial_dims = std::max(num_pad_dims, SPATIAL_DIM_DEFAULT_SIZE); | |||
| int num_spatial_dims = std::max(num_pad_dims, 2); | |||
| std::vector<int64_t> vec; | |||
| vec.reserve(num_spatial_dims); | |||
| for (int i = 0; i < num_spatial_dims; ++i) { | |||
| vec.push_back((num_pad_dims == 0) ? PAD_DEFAULT_VALUE : convParam.pad((num_pad_dims == 1) ? 0 : i)); | |||
| vec.push_back((num_pad_dims == 0) ? 0 : convParam.pad((num_pad_dims == 1) ? 0 : i)); | |||
| } | |||
| // default 2D | |||
| (*pad)[0] = vec[0]; | |||
| (*pad)[1] = vec[0]; | |||
| (*pad)[2] = vec[1]; | |||
| @@ -87,13 +79,13 @@ STATUS CaffeConvBaseParser::ParseStrides(const caffe::ConvolutionParameter &conv | |||
| (*stride)[1] = convParam.stride_w(); | |||
| } else { | |||
| const int num_stride_dims = convParam.stride_size(); | |||
| int num_spatial_dims = std::max(num_stride_dims, SPATIAL_DIM_DEFAULT_SIZE); | |||
| int num_spatial_dims = std::max(num_stride_dims, 2); | |||
| std::vector<int64_t> vec; | |||
| vec.reserve(num_spatial_dims); | |||
| for (int i = 0; i < num_spatial_dims; ++i) { | |||
| vec.push_back((num_stride_dims == 0) ? STRIDE_DEFAULT_VALUE : convParam.stride((num_stride_dims == 1) ? 0 : i)); | |||
| vec.push_back((num_stride_dims == 0) ? 1 : convParam.stride((num_stride_dims == 1) ? 0 : i)); | |||
| } | |||
| // default 2D | |||
| (*stride)[0] = vec[0]; | |||
| (*stride)[1] = vec[1]; | |||
| } | |||
| @@ -103,17 +95,15 @@ STATUS CaffeConvBaseParser::ParseStrides(const caffe::ConvolutionParameter &conv | |||
| STATUS CaffeConvBaseParser::ParseDilations(const caffe::ConvolutionParameter &convParam, | |||
| std::vector<int64_t> *dilation) { | |||
| const int num_dilation_dims = convParam.dilation_size(); | |||
| int num_spatial_dims = std::max(num_dilation_dims, SPATIAL_DIM_DEFAULT_SIZE); | |||
| int num_spatial_dims = std::max(num_dilation_dims, 2); | |||
| std::vector<int64_t> vec; | |||
| vec.reserve(num_spatial_dims); | |||
| for (int i = 0; i < num_spatial_dims; ++i) { | |||
| vec.push_back((num_dilation_dims == 0) ? DILATION_DEFAULT_VALUE | |||
| : convParam.dilation((num_dilation_dims == 1) ? 0 : i)); | |||
| vec.push_back((num_dilation_dims == 0) ? 1 : convParam.dilation((num_dilation_dims == 1) ? 0 : i)); | |||
| } | |||
| // default 2D | |||
| (*dilation)[0] = vec[0]; | |||
| (*dilation)[1] = vec[1]; | |||
| return RET_OK; | |||
| } | |||
| @@ -131,9 +121,11 @@ STATUS CaffeConvBaseParser::ParseKernels(const caffe::ConvolutionParameter &conv | |||
| return RET_ERROR; | |||
| } | |||
| } else if (convParam.kernel_size_size() != 0) { | |||
| int kernel_size = convParam.kernel_size_size(); | |||
| int num_spatial_dims = std::max(kernel_size, SPATIAL_DIM_DEFAULT_SIZE); | |||
| const int kernel_size = convParam.kernel_size_size(); | |||
| int num_spatial_dims = std::max(kernel_size, 2); | |||
| std::vector<int64_t> vec; | |||
| vec.reserve(num_spatial_dims); | |||
| for (int i = 0; i < num_spatial_dims; i++) { | |||
| vec.push_back(convParam.kernel_size((kernel_size == 1) ? 0 : i)); | |||
| } | |||
| @@ -141,24 +133,25 @@ STATUS CaffeConvBaseParser::ParseKernels(const caffe::ConvolutionParameter &conv | |||
| (*kernel)[0] = vec[0]; | |||
| (*kernel)[1] = vec[1]; | |||
| } else { | |||
| MS_LOG(ERROR) << "conv does not have kernel info."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int CaffeConvBaseParser::ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType) { | |||
| // group default 1 | |||
| int group = 0; | |||
| if (convParam.has_group()) { | |||
| group = convParam.group(); | |||
| return convParam.group(); | |||
| } else { | |||
| layerType == "ConvolutionDepthwise" ? (group = convParam.num_output()) : (group = DEFAULT_CONV_GROUP); | |||
| return layerType == "ConvolutionDepthwise" ? static_cast<int>(convParam.num_output()) : 1; | |||
| } | |||
| return group; | |||
| } | |||
| int CaffeConvBaseParser::ParseChannelOut(const caffe::ConvolutionParameter &convParam, int32_t *channelOut) { | |||
| MS_ASSERT(channelOut != nullptr); | |||
| if (channelOut == nullptr) { | |||
| MS_LOG(ERROR) << "channelOut is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (!convParam.has_num_output()) { | |||
| MS_LOG(ERROR) << "Parse num_output for failed."; | |||
| return RET_ERROR; | |||
| @@ -169,7 +162,11 @@ int CaffeConvBaseParser::ParseChannelOut(const caffe::ConvolutionParameter &conv | |||
| STATUS CaffeConvBaseParser::ParseWeight(const caffe::LayerParameter &weight, | |||
| std::vector<schema::TensorT *> *weightVec) { | |||
| // Layer must have Filter | |||
| if (weightVec == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (weight.blobs_size() == 0) { | |||
| MS_LOG(ERROR) << "No filter data in layer " << weight.name().c_str(); | |||
| return RET_ERROR; | |||
| @@ -182,8 +179,7 @@ STATUS CaffeConvBaseParser::ParseWeight(const caffe::LayerParameter &weight, | |||
| } | |||
| weightVec->push_back(filter); | |||
| // parse bias | |||
| const caffe::ConvolutionParameter convParam = weight.convolution_param(); | |||
| const caffe::ConvolutionParameter &convParam = weight.convolution_param(); | |||
| if (convParam.bias_term() && weight.blobs_size() > 1) { | |||
| auto bias = ConvertWeight(weight.blobs(1)); | |||
| if (bias == nullptr) { | |||
| @@ -192,7 +188,7 @@ STATUS CaffeConvBaseParser::ParseWeight(const caffe::LayerParameter &weight, | |||
| } | |||
| std::vector<int32_t> shape = bias->dims; | |||
| if (shape.size() != CAFFE_CONV_BIAS_DIM_NUM) { | |||
| if (shape.size() != 1) { | |||
| MS_LOG(ERROR) << "Bias dim-num of layer " << weight.name().c_str() << " is not supported"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -26,23 +26,23 @@ namespace mindspore { | |||
| namespace lite { | |||
| class CaffeConvBaseParser { | |||
| public: | |||
| CaffeConvBaseParser() {} | |||
| CaffeConvBaseParser() = default; | |||
| virtual ~CaffeConvBaseParser() {} | |||
| virtual ~CaffeConvBaseParser() = default; | |||
| STATUS ParsePads(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *pad); | |||
| static STATUS ParsePads(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *pad); | |||
| STATUS ParseStrides(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *stride); | |||
| static STATUS ParseStrides(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *stride); | |||
| STATUS ParseDilations(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *dilation); | |||
| static STATUS ParseDilations(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *dilation); | |||
| STATUS ParseKernels(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *kernel); | |||
| static STATUS ParseKernels(const caffe::ConvolutionParameter &conv_param, std::vector<int64_t> *kernel); | |||
| int ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType); | |||
| static int ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType); | |||
| int ParseChannelOut(const caffe::ConvolutionParameter &convParam, int32_t *channelOut); | |||
| static int ParseChannelOut(const caffe::ConvolutionParameter &convParam, int32_t *channelOut); | |||
| STATUS ParseWeight(const caffe::LayerParameter &weight, std::vector<schema::TensorT *> *weightVec); | |||
| static STATUS ParseWeight(const caffe::LayerParameter &weight, std::vector<schema::TensorT *> *weightVec); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -54,7 +54,10 @@ STATUS CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema: | |||
| STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||
| MS_LOG(DEBUG) << "parse CaffeConvolutionParser"; | |||
| if (weightVec == nullptr) { | |||
| MS_LOG(ERROR) << "weightVec is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -73,11 +76,10 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||
| attr->format = schema::Format_NCHW; | |||
| const caffe::ConvolutionParameter convParam = proto.convolution_param(); | |||
| CaffeConvBaseParser convParser; | |||
| const caffe::ConvolutionParameter &convParam = proto.convolution_param(); | |||
| // parse pad | |||
| std::vector<int64_t> pad(4, 0); | |||
| auto status = convParser.ParsePads(convParam, &pad); | |||
| auto status = CaffeConvBaseParser::ParsePads(convParam, &pad); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -89,7 +91,7 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||
| // parse stride | |||
| std::vector<int64_t> stride(2, 0); | |||
| status = convParser.ParseStrides(convParam, &stride); | |||
| status = CaffeConvBaseParser::ParseStrides(convParam, &stride); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -99,7 +101,7 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||
| // parse dilation | |||
| std::vector<int64_t> dilation(2, 0); | |||
| status = convParser.ParseDilations(convParam, &dilation); | |||
| status = CaffeConvBaseParser::ParseDilations(convParam, &dilation); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -109,7 +111,7 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||
| // parse kernel | |||
| std::vector<int64_t> kernel(2, 0); | |||
| status = convParser.ParseKernels(convParam, &kernel); | |||
| status = CaffeConvBaseParser::ParseKernels(convParam, &kernel); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -118,8 +120,8 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||
| attr->kernelW = kernel[1]; | |||
| attr->hasBias = convParam.bias_term(); | |||
| attr->group = convParser.ParseGroup(convParam, proto.type()); | |||
| auto ret = convParser.ParseChannelOut(convParam, &(attr->channelOut)); | |||
| attr->group = CaffeConvBaseParser::ParseGroup(convParam, proto.type()); | |||
| auto ret = CaffeConvBaseParser::ParseChannelOut(convParam, &(attr->channelOut)); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "conv channel out failed"; | |||
| return RET_ERROR; | |||
| @@ -128,7 +130,6 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||
| if (weightBlob.has_shape()) { | |||
| attr->channelIn = weightBlob.shape().dim(1) * attr->group; | |||
| } else { | |||
| // get shape information from Blob parameters(caffe proto v1) | |||
| attr->channelIn = weightBlob.channels() * attr->group; | |||
| } | |||
| attr->padMode = schema::PadMode_CAFFE; | |||
| @@ -143,7 +144,7 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c | |||
| return RET_ERROR; | |||
| } | |||
| status = convParser.ParseWeight(weight, weightVec); | |||
| status = CaffeConvBaseParser::ParseWeight(weight, weightVec); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseWeight for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -27,13 +27,13 @@ namespace lite { | |||
| class CaffeConvolutionParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeConvolutionParser() : CaffeNodeParser("convolution") {} | |||
| ~CaffeConvolutionParser() = default; | |||
| ~CaffeConvolutionParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| private: | |||
| STATUS ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr); | |||
| static STATUS ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,13 +17,15 @@ | |||
| #include "tools/converter/parser/caffe/caffe_crop_parser.h" | |||
| #include <memory> | |||
| const int32_t CROP_AXIS = 2; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeCropParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||
| MS_LOG(DEBUG) << "parse CaffeCropParser"; | |||
| if (weightVec == nullptr) { | |||
| MS_LOG(ERROR) << "weightVec is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -41,22 +43,23 @@ STATUS CaffeCropParser::Parse(const caffe::LayerParameter &proto, const caffe::L | |||
| } | |||
| if (!proto.has_crop_param()) { | |||
| attr->axis = CROP_AXIS; | |||
| attr->axis = 2; | |||
| std::vector<int64_t> offsets(2, 0); | |||
| attr->offsets = offsets; | |||
| } else { | |||
| const caffe::CropParameter cropParam = proto.crop_param(); | |||
| const caffe::CropParameter &cropParam = proto.crop_param(); | |||
| if (cropParam.has_axis()) { | |||
| if (cropParam.axis() == -1) { | |||
| MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||
| } | |||
| attr->axis = cropParam.axis(); | |||
| } else { | |||
| attr->axis = CROP_AXIS; | |||
| attr->axis = 2; | |||
| } | |||
| if (cropParam.offset_size() != 0) { | |||
| std::vector<int64_t> offsets; | |||
| offsets.reserve(cropParam.offset_size()); | |||
| for (int i = 0; i < cropParam.offset_size(); i++) { | |||
| offsets.push_back(cropParam.offset(i)); | |||
| } | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeCropParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeCropParser() : CaffeNodeParser("crop") {} | |||
| ~CaffeCropParser() = default; | |||
| ~CaffeCropParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -54,7 +54,10 @@ STATUS CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, sch | |||
| STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||
| MS_LOG(DEBUG) << "parse CaffeDeconvolutionParser"; | |||
| if (weightVec == nullptr) { | |||
| MS_LOG(ERROR) << "weightVec is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -69,11 +72,10 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||
| attr->format = schema::Format::Format_NCHW; | |||
| const caffe::ConvolutionParameter convParam = proto.convolution_param(); | |||
| CaffeConvBaseParser convParser; | |||
| const caffe::ConvolutionParameter &convParam = proto.convolution_param(); | |||
| // parse pad | |||
| std::vector<int64_t> pad(4, 0); | |||
| auto status = convParser.ParsePads(convParam, &pad); | |||
| auto status = CaffeConvBaseParser::ParsePads(convParam, &pad); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -85,7 +87,7 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||
| // parse stride | |||
| std::vector<int64_t> stride(2, 0); | |||
| status = convParser.ParseStrides(convParam, &stride); | |||
| status = CaffeConvBaseParser::ParseStrides(convParam, &stride); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -95,7 +97,7 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||
| // parse dilation | |||
| std::vector<int64_t> dilation(2, 0); | |||
| status = convParser.ParseDilations(convParam, &dilation); | |||
| status = CaffeConvBaseParser::ParseDilations(convParam, &dilation); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -105,7 +107,7 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||
| // parse kernel | |||
| std::vector<int64_t> kernel(2, 0); | |||
| status = convParser.ParseKernels(convParam, &kernel); | |||
| status = CaffeConvBaseParser::ParseKernels(convParam, &kernel); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -114,8 +116,8 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||
| attr->kernelW = kernel[1]; | |||
| attr->hasBias = convParam.bias_term(); | |||
| attr->group = convParser.ParseGroup(convParam, proto.type()); | |||
| auto ret = convParser.ParseChannelOut(convParam, &(attr->channelOut)); | |||
| attr->group = CaffeConvBaseParser::ParseGroup(convParam, proto.type()); | |||
| auto ret = CaffeConvBaseParser::ParseChannelOut(convParam, &(attr->channelOut)); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "deconv channel get failed"; | |||
| return RET_ERROR; | |||
| @@ -127,7 +129,6 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||
| else | |||
| attr->channelIn = weightBlob.shape().dim(1) * attr->group; | |||
| } else { | |||
| // get shape information from Blob parameters(caffe proto v1) | |||
| attr->channelIn = weightBlob.num() * attr->group; | |||
| } | |||
| attr->padMode = schema::PadMode_CAFFE; | |||
| @@ -142,7 +143,7 @@ STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const | |||
| return RET_ERROR; | |||
| } | |||
| status = convParser.ParseWeight(weight, weightVec); | |||
| status = CaffeConvBaseParser::ParseWeight(weight, weightVec); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseWeight for " << proto.name().c_str() << " failed"; | |||
| return RET_ERROR; | |||
| @@ -27,13 +27,13 @@ namespace lite { | |||
| class CaffeDeconvolutionParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeDeconvolutionParser() : CaffeNodeParser("deconvolution") {} | |||
| ~CaffeDeconvolutionParser() = default; | |||
| ~CaffeDeconvolutionParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| private: | |||
| STATUS ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr); | |||
| static STATUS ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -18,9 +18,6 @@ | |||
| #include <cmath> | |||
| #include <memory> | |||
| const int ELTWISE_MIN_INPUT_SIZE = 2; | |||
| const float ELTWISE_SUM_COEFF_EPSILON = 1e-5; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| @@ -42,13 +39,13 @@ STATUS CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (proto.bottom_size() < ELTWISE_MIN_INPUT_SIZE) { | |||
| if (proto.bottom_size() < 2) { | |||
| MS_LOG(ERROR) << "Eltwise Op " << proto.name() << " need at least 2 inputs,but input size is " | |||
| << proto.bottom_size(); | |||
| return RET_ERROR; | |||
| } | |||
| const caffe::EltwiseParameter eltwiseParam = proto.eltwise_param(); | |||
| const caffe::EltwiseParameter &eltwiseParam = proto.eltwise_param(); | |||
| if (eltwiseParam.coeff_size() != 0 && eltwiseParam.coeff_size() != proto.bottom_size()) { | |||
| MS_LOG(ERROR) << "Coeff size(" << eltwiseParam.coeff_size() | |||
| << ") check fail, Eltwise Layer takes one coefficient per bottom blob."; | |||
| @@ -60,8 +57,8 @@ STATUS CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe | |||
| return RET_ERROR; | |||
| } | |||
| if (eltwiseParam.coeff_size() != 0 && (fabs(eltwiseParam.coeff(0) - 1) > ELTWISE_SUM_COEFF_EPSILON || | |||
| fabs(eltwiseParam.coeff(1) - 1) > ELTWISE_SUM_COEFF_EPSILON)) { | |||
| if (eltwiseParam.coeff_size() != 0 && | |||
| (std::fabs(eltwiseParam.coeff(0) - 1) > 1e-5 || std::fabs(eltwiseParam.coeff(1) - 1) > 1e-5)) { | |||
| MS_LOG(ERROR) << "Eltwise only support coefficient 1 for summation now."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeEltwiseParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeEltwiseParser() : CaffeNodeParser("eltwise") {} | |||
| ~CaffeEltwiseParser() = default; | |||
| ~CaffeEltwiseParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -39,7 +39,7 @@ STATUS CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::La | |||
| } | |||
| if (proto.has_elu_param()) { | |||
| const caffe::ELUParameter eluParameter = proto.elu_param(); | |||
| const caffe::ELUParameter &eluParameter = proto.elu_param(); | |||
| if (eluParameter.has_alpha()) { | |||
| attr->alpha = eluParameter.alpha(); | |||
| } | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeEluParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeEluParser() : CaffeNodeParser("elu") {} | |||
| ~CaffeEluParser() = default; | |||
| ~CaffeEluParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -39,7 +39,7 @@ STATUS CaffeExpParser::Parse(const caffe::LayerParameter &proto, const caffe::La | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::ExpParameter exp_param = proto.exp_param(); | |||
| const caffe::ExpParameter &exp_param = proto.exp_param(); | |||
| if (exp_param.has_base()) { | |||
| attr->base = exp_param.base(); | |||
| } else { | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeExpParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeExpParser() : CaffeNodeParser("exp") {} | |||
| ~CaffeExpParser() = default; | |||
| ~CaffeExpParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeFlattenParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeFlattenParser() : CaffeNodeParser("flatten") {} | |||
| ~CaffeFlattenParser() = default; | |||
| ~CaffeFlattenParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -22,6 +22,10 @@ namespace lite { | |||
| STATUS CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||
| MS_LOG(DEBUG) << "parse CaffeInnerProductParser"; | |||
| if (weightVec == nullptr) { | |||
| MS_LOG(ERROR) << "weightVec is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -38,7 +42,7 @@ STATUS CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, const | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::InnerProductParameter innerProductParam = proto.inner_product_param(); | |||
| const caffe::InnerProductParameter &innerProductParam = proto.inner_product_param(); | |||
| if (!innerProductParam.has_num_output()) { | |||
| MS_LOG(ERROR) << "InnerProduct Parse num_output for " << proto.name().c_str() << " failed."; | |||
| return RET_ERROR; | |||
| @@ -62,8 +66,6 @@ STATUS CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, const | |||
| MS_LOG(ERROR) << "InnerProduct No filter data in layer " << weight.name().c_str(); | |||
| return RET_ERROR; | |||
| } | |||
| // parse filter | |||
| auto filter = ConvertWeight(weight.blobs(0)); | |||
| if (filter == nullptr) { | |||
| MS_LOG(ERROR) << "InnerProduct parse weight for layer " << weight.name().c_str() << " failed"; | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeInnerProductParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeInnerProductParser() : CaffeNodeParser("innerproduct") {} | |||
| ~CaffeInnerProductParser() = default; | |||
| ~CaffeInnerProductParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -47,12 +47,12 @@ STATUS CaffeInspector::ParseInput() { | |||
| } | |||
| STATUS CaffeInspector::FindInputAndOutput() { | |||
| for (auto iter : layerBottoms) { | |||
| for (const auto &iter : layerBottoms) { | |||
| if (layerTops.find(iter) == layerTops.end()) { | |||
| graphInput.insert(iter); | |||
| } | |||
| } | |||
| for (auto iter : layerTops) { | |||
| for (const auto &iter : layerTops) { | |||
| if (layerBottoms.find(iter) == layerBottoms.end()) { | |||
| graphOutput.insert(iter); | |||
| } | |||
| @@ -62,7 +62,7 @@ STATUS CaffeInspector::FindInputAndOutput() { | |||
| STATUS CaffeInspector::SetTopsAndBottoms() { | |||
| for (int32_t i = 0; i < net.layer_size(); i++) { | |||
| caffe::LayerParameter &layer = const_cast<caffe::LayerParameter &>(net.layer(i)); | |||
| auto &layer = const_cast<caffe::LayerParameter &>(net.layer(i)); | |||
| if (layer.top_size() == 1 && layer.bottom_size() == 1 && layer.top(0) == layer.bottom(0)) { | |||
| continue; | |||
| } | |||
| @@ -38,7 +38,7 @@ STATUS CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::InterpParameter interpParam = proto.interp_param(); | |||
| const caffe::InterpParameter &interpParam = proto.interp_param(); | |||
| if (interpParam.has_height()) { | |||
| int64_t height = interpParam.height(); | |||
| if (height < 0) { | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeInterpParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeInterpParser() : CaffeNodeParser("Interp") {} | |||
| ~CaffeInterpParser() = default; | |||
| ~CaffeInterpParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -23,6 +23,11 @@ namespace mindspore { | |||
| namespace lite { | |||
| schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||
| std::unique_ptr<schema::TensorT> weight = std::make_unique<schema::TensorT>(); | |||
| if (weight == nullptr) { | |||
| MS_LOG(ERROR) << "new weight failed"; | |||
| return nullptr; | |||
| } | |||
| weight->format = schema::Format::Format_NCHW; | |||
| std::vector<int32_t> shapeVec; | |||
| ConvertShape(proto, &shapeVec); | |||
| @@ -32,8 +37,7 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||
| // cal Weight num | |||
| int count = 1; | |||
| for (size_t i = 0; i < shapeVec.size(); ++i) { | |||
| int dim = shapeVec[i]; | |||
| for (int dim : shapeVec) { | |||
| if (dim <= 0) { | |||
| MS_LOG(ERROR) << "Convert weight fail, Blob size invalid"; | |||
| return nullptr; | |||
| @@ -48,6 +52,7 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||
| // get weight | |||
| std::unique_ptr<float[]> buf = std::make_unique<float[]>(count); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "new weight buf failed"; | |||
| return nullptr; | |||
| } | |||
| if (proto.double_data_size() > 0) { | |||
| @@ -74,6 +79,7 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||
| << "blob.data_size:%d" << proto.data_size(); | |||
| return nullptr; | |||
| } | |||
| weight->data.resize(count * sizeof(float)); | |||
| const float *data_ptr = proto.data().data(); | |||
| if (data_ptr == nullptr) { | |||
| @@ -91,8 +97,12 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||
| } | |||
| STATUS ConvertShape(const caffe::BlobProto &proto, std::vector<int32_t> *shape) { | |||
| shape->clear(); | |||
| if (shape == nullptr) { | |||
| MS_LOG(ERROR) << "shape is null"; | |||
| return RET_ERROR; | |||
| } | |||
| shape->clear(); | |||
| if (proto.has_num() || proto.has_channels() || proto.has_height() || proto.has_width()) { | |||
| shape->push_back(proto.num()); | |||
| shape->push_back(proto.channels()); | |||
| @@ -18,7 +18,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| CaffeNodeParserRegistry::CaffeNodeParserRegistry() {} | |||
| CaffeNodeParserRegistry::CaffeNodeParserRegistry() = default; | |||
| CaffeNodeParserRegistry::~CaffeNodeParserRegistry() { | |||
| for (auto ite : parsers) { | |||
| @@ -38,7 +38,7 @@ STATUS CaffePermuteParser::Parse(const caffe::LayerParameter &proto, const caffe | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::PermuteParameter permuteParam = proto.permute_param(); | |||
| const caffe::PermuteParameter &permuteParam = proto.permute_param(); | |||
| const int num_order_dims = permuteParam.order_size(); | |||
| attr->perm.resize(num_order_dims); | |||
| for (int i = 0; i < num_order_dims; ++i) { | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffePermuteParser : public CaffeNodeParser { | |||
| public: | |||
| CaffePermuteParser() : CaffeNodeParser("Permute") {} | |||
| ~CaffePermuteParser() = default; | |||
| ~CaffePermuteParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -17,9 +17,6 @@ | |||
| #include "tools/converter/parser/caffe/caffe_pooling_parser.h" | |||
| #include <memory> | |||
| const uint32_t INNERPRODUCT_WINDOW_DEFAULT_VALUE = 0; | |||
| const uint32_t INNERPRODUCT_PAD_DEFAULT_VALUE = 0; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffePoolingParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| @@ -43,7 +40,7 @@ STATUS CaffePoolingParser::Parse(const caffe::LayerParameter &proto, const caffe | |||
| attr->format = schema::Format::Format_NCHW; | |||
| const caffe::PoolingParameter poolingParam = proto.pooling_param(); | |||
| const caffe::PoolingParameter &poolingParam = proto.pooling_param(); | |||
| auto status = ParsePads(poolingParam, attr.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | |||
| @@ -68,15 +65,12 @@ STATUS CaffePoolingParser::Parse(const caffe::LayerParameter &proto, const caffe | |||
| return RET_ERROR; | |||
| } | |||
| // default roundMode RoundMode_CEIL | |||
| attr->roundMode = schema::RoundMode_CEIL; | |||
| if (poolingParam.has_round_mode()) { | |||
| if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_FLOOR) { | |||
| attr->roundMode = schema::RoundMode_FLOOR; | |||
| } else if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_CEIL) { | |||
| attr->roundMode = schema::RoundMode_CEIL; | |||
| } else { | |||
| MS_ASSERT(false); | |||
| } | |||
| } | |||
| attr->padMode = schema::PadMode_CAFFE; | |||
| @@ -127,8 +121,8 @@ STATUS CaffePoolingParser::ParseWindows(const caffe::PoolingParameter &poolingPa | |||
| MS_LOG(ERROR) << "With Global_pooling: true Filter size cannot specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->windowH = INNERPRODUCT_WINDOW_DEFAULT_VALUE; | |||
| attr->windowW = INNERPRODUCT_WINDOW_DEFAULT_VALUE; | |||
| attr->windowH = 0; | |||
| attr->windowW = 0; | |||
| attr->global = true; | |||
| } else { | |||
| if (poolingParam.has_kernel_size() == (poolingParam.has_kernel_h() || poolingParam.has_kernel_w())) { | |||
| @@ -157,7 +151,7 @@ STATUS CaffePoolingParser::ParsePoolingMode(const caffe::PoolingParameter &pooli | |||
| } else if (poolingParam.pool() == caffe::PoolingParameter::AVE) { | |||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | |||
| } else { | |||
| MS_LOG(ERROR) << "Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."; | |||
| MS_LOG(ERROR) << "MindSpore support MAX and AVE PoolingMode only."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| @@ -26,18 +26,18 @@ namespace lite { | |||
| class CaffePoolingParser : public CaffeNodeParser { | |||
| public: | |||
| CaffePoolingParser() : CaffeNodeParser("pooling") {} | |||
| ~CaffePoolingParser() = default; | |||
| ~CaffePoolingParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| STATUS ParsePads(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||
| static STATUS ParsePads(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||
| STATUS ParseStrides(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||
| static STATUS ParseStrides(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||
| STATUS ParseWindows(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||
| static STATUS ParseWindows(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||
| STATUS ParsePoolingMode(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||
| static STATUS ParsePoolingMode(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -18,10 +18,6 @@ | |||
| #include <memory> | |||
| #include <vector> | |||
| static const float CAFFE_POWER_DEFAULT_POWER = 1.0; | |||
| static const float CAFFE_POWER_DEFAULT_SCALE = 1.0; | |||
| static const float CAFFE_POWER_DEFAULT_SHIFT = 0.0; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffePowerParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| @@ -43,15 +39,15 @@ STATUS CaffePowerParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::PowerParameter powerParam = proto.power_param(); | |||
| const caffe::PowerParameter &powerParam = proto.power_param(); | |||
| if (proto.has_power_param()) { | |||
| attr->power = powerParam.has_power() ? powerParam.power() : CAFFE_POWER_DEFAULT_POWER; | |||
| attr->scale = powerParam.has_scale() ? powerParam.scale() : CAFFE_POWER_DEFAULT_SCALE; | |||
| attr->shift = powerParam.has_shift() ? powerParam.shift() : CAFFE_POWER_DEFAULT_SHIFT; | |||
| attr->power = powerParam.has_power() ? powerParam.power() : 1.0; | |||
| attr->scale = powerParam.has_scale() ? powerParam.scale() : 1.0; | |||
| attr->shift = powerParam.has_shift() ? powerParam.shift() : 0.0; | |||
| } else { | |||
| attr->power = CAFFE_POWER_DEFAULT_POWER; | |||
| attr->scale = CAFFE_POWER_DEFAULT_SCALE; | |||
| attr->shift = CAFFE_POWER_DEFAULT_SHIFT; | |||
| attr->power = 1.0; | |||
| attr->scale = 1.0; | |||
| attr->shift = 0.0; | |||
| } | |||
| op->name = proto.name(); | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffePowerParser : public CaffeNodeParser { | |||
| public: | |||
| CaffePowerParser() : CaffeNodeParser("power") {} | |||
| ~CaffePowerParser() = default; | |||
| ~CaffePowerParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -22,6 +22,10 @@ namespace lite { | |||
| STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||
| MS_LOG(DEBUG) << "parse CaffePReluParser"; | |||
| if (weightVec == nullptr) { | |||
| MS_LOG(ERROR) << "weightVec is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -38,7 +42,7 @@ STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::PReLUParameter pReluParam = proto.prelu_param(); | |||
| const caffe::PReLUParameter &pReluParam = proto.prelu_param(); | |||
| if (pReluParam.has_channel_shared()) { | |||
| attr->channelShared = pReluParam.channel_shared(); | |||
| } else { | |||
| @@ -49,7 +53,6 @@ STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||
| MS_LOG(ERROR) << "PRelu No blobs data in layer " << proto.name().c_str(); | |||
| return RET_ERROR; | |||
| } | |||
| auto slope = ConvertWeight(weight.blobs(0)); | |||
| if (slope == nullptr) { | |||
| MS_LOG(ERROR) << "CaffePRelu convert slope for layer " << weight.name().c_str() << " failed."; | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffePReluParser : public CaffeNodeParser { | |||
| public: | |||
| CaffePReluParser() : CaffeNodeParser("pRelu") {} | |||
| ~CaffePReluParser() = default; | |||
| ~CaffePReluParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -39,7 +39,7 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::ReductionParameter reduce_param = proto.reduction_param(); | |||
| const caffe::ReductionParameter &reduce_param = proto.reduction_param(); | |||
| if (reduce_param.has_operation()) { | |||
| switch (reduce_param.operation()) { | |||
| case caffe::ReductionParameter_ReductionOp_MEAN: | |||
| @@ -72,6 +72,7 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||
| } | |||
| attr->reduceToEnd = true; | |||
| attr->keepDims = false; | |||
| op->name = proto.name(); | |||
| op->primitive->value.type = schema::PrimitiveType_Reduce; | |||
| op->primitive->value.value = attr.release(); | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeReduceParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeReduceParser() : CaffeNodeParser("reduce") {} | |||
| ~CaffeReduceParser() = default; | |||
| ~CaffeReduceParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -39,8 +39,6 @@ STATUS CaffeRelu6Parser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||
| } | |||
| attr->type = schema::ActivationType_RELU6; | |||
| // relu: negative_slope = 0, no parameter; | |||
| // leakyrelu: negative_slope != 0; | |||
| if (proto.has_relu_param() && proto.relu_param().has_negative_slope()) { | |||
| float negative_slope = proto.relu_param().negative_slope(); | |||
| if (0 != negative_slope) { | |||
| @@ -25,7 +25,7 @@ namespace lite { | |||
| class CaffeRelu6Parser : public CaffeNodeParser { | |||
| public: | |||
| CaffeRelu6Parser() : CaffeNodeParser("relu6") {} | |||
| ~CaffeRelu6Parser() = default; | |||
| ~CaffeRelu6Parser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -40,7 +40,7 @@ STATUS CaffeReshapeParser::Parse(const caffe::LayerParameter &proto, const caffe | |||
| attr->format = schema::Format::Format_NCHW; | |||
| const caffe::ReshapeParameter reshapeParam = proto.reshape_param(); | |||
| const caffe::ReshapeParameter &reshapeParam = proto.reshape_param(); | |||
| if (!reshapeParam.has_shape()) { | |||
| MS_LOG(ERROR) << "Reshape has no shape info, ret fail"; | |||
| return RET_ERROR; | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeReshapeParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeReshapeParser() : CaffeNodeParser("reshape") {} | |||
| ~CaffeReshapeParser() = default; | |||
| ~CaffeReshapeParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -17,14 +17,15 @@ | |||
| #include "tools/converter/parser/caffe/caffe_scale_parser.h" | |||
| #include <memory> | |||
| const int32_t NCHW_DIM_C = 1; | |||
| const int32_t DIM_DEFAULT_SIZE = 4; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||
| MS_LOG(DEBUG) << "parse CaffeScaleParser"; | |||
| if (weightVec == nullptr) { | |||
| MS_LOG(ERROR) << "weightVec is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -47,10 +48,10 @@ STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||
| return RET_ERROR; | |||
| } | |||
| const caffe::ScaleParameter scaleParam = weight.scale_param(); | |||
| int axis = NCHW_DIM_C; | |||
| const caffe::ScaleParameter &scaleParam = weight.scale_param(); | |||
| int axis = 1; | |||
| if (scaleParam.has_axis()) { | |||
| uint32_t axis_index = NCHW_DIM_C; | |||
| uint32_t axis_index = 1; | |||
| if (GetAxisIndex(scaleParam.axis(), &axis_index)) { | |||
| MS_LOG(ERROR) << "scale get axis failed for layer " << weight.name().c_str(); | |||
| return RET_ERROR; | |||
| @@ -93,7 +94,7 @@ STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||
| } | |||
| STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) { | |||
| if (axis < -DIM_DEFAULT_SIZE || axis >= DIM_DEFAULT_SIZE) { | |||
| if (axis < -4 || axis >= 4) { | |||
| MS_LOG(ERROR) << "Scale axis value(" << axis << ") is not correct"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -102,7 +103,7 @@ STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) | |||
| MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||
| } | |||
| *axis_index = (axis + DIM_DEFAULT_SIZE) % DIM_DEFAULT_SIZE; | |||
| *axis_index = (axis + 4) % 4; | |||
| return RET_OK; | |||
| } | |||
| @@ -26,12 +26,12 @@ namespace lite { | |||
| class CaffeScaleParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeScaleParser() : CaffeNodeParser("scale") {} | |||
| ~CaffeScaleParser() = default; | |||
| ~CaffeScaleParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| STATUS GetAxisIndex(const int32_t &axis, uint32_t *axis_index); | |||
| static STATUS GetAxisIndex(const int32_t &axis, uint32_t *axis_index); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeSigmoidParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeSigmoidParser() : CaffeNodeParser("sigmoid") {} | |||
| ~CaffeSigmoidParser() = default; | |||
| ~CaffeSigmoidParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -33,7 +33,6 @@ STATUS CaffeSliceParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||
| } | |||
| std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| @@ -56,12 +55,12 @@ STATUS CaffeSliceParser::Parse(const caffe::LayerParameter &proto, const caffe:: | |||
| attr->sizeSplits = size_splits; | |||
| } | |||
| // The axis along which to slice -- may be negative to index from the end (e.g., -1 for the last axis). | |||
| if (slice_param.has_axis()) { | |||
| attr->splitDim = slice_param.axis(); | |||
| } else if (slice_param.has_slice_dim()) { | |||
| attr->splitDim = slice_param.slice_dim(); | |||
| } | |||
| op->name = proto.name(); | |||
| op->primitive->value.type = schema::PrimitiveType_Split; | |||
| op->primitive->value.value = attr.release(); | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeSliceParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeSliceParser() : CaffeNodeParser("slice") {} | |||
| ~CaffeSliceParser() = default; | |||
| ~CaffeSliceParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -17,8 +17,6 @@ | |||
| #include "tools/converter/parser/caffe/caffe_softmax_parser.h" | |||
| #include <memory> | |||
| static const int32_t CAFFE_SOFTMAX_DEFAULT_AXIS = 1; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeSoftmaxParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| @@ -42,11 +40,11 @@ STATUS CaffeSoftmaxParser::Parse(const caffe::LayerParameter &proto, const caffe | |||
| if (proto.has_softmax_param() && proto.softmax_param().has_axis()) { | |||
| if (proto.softmax_param().axis() == -1) { | |||
| MS_LOG(ERROR) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||
| MS_LOG(DEBUG) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||
| } | |||
| attr->axis = proto.softmax_param().axis(); | |||
| } else { | |||
| attr->axis = CAFFE_SOFTMAX_DEFAULT_AXIS; | |||
| attr->axis = 1; | |||
| } | |||
| op->name = proto.name(); | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeSoftmaxParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeSoftmaxParser() : CaffeNodeParser("softmax") {} | |||
| ~CaffeSoftmaxParser() = default; | |||
| ~CaffeSoftmaxParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeTanhParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeTanhParser() : CaffeNodeParser("tanh") {} | |||
| ~CaffeTanhParser() = default; | |||
| ~CaffeTanhParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -39,7 +39,7 @@ STATUS CaffeTileParser::Parse(const caffe::LayerParameter &proto, const caffe::L | |||
| return RET_NULL_PTR; | |||
| } | |||
| const caffe::TileParameter tile_param = proto.tile_param(); | |||
| const caffe::TileParameter &tile_param = proto.tile_param(); | |||
| std::vector<int> dims; | |||
| std::vector<int> multiples; | |||
| dims.clear(); | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| class CaffeTileParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeTileParser() : CaffeNodeParser("tile") {} | |||
| ~CaffeTileParser() = default; | |||
| ~CaffeTileParser() override = default; | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| @@ -18,6 +18,7 @@ | |||
| #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | |||
| #include <memory> | |||
| #include <numeric> | |||
| #include <functional> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -266,7 +267,6 @@ STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||
| return RET_NULL_PTR; | |||
| } | |||
| // there is no Prod in onnx | |||
| if (onnx_node.op_type() == "Sum") { | |||
| attr->mode = schema::EltwiseMode_SUM; | |||
| } else if (onnx_node.op_type() == "Max") { | |||
| @@ -50,13 +50,6 @@ class OnnxDivParser : public OnnxNodeParser { | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| }; | |||
| class OnnxMeanParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxMeanParser() : OnnxNodeParser("Mean") {} | |||
| ~OnnxMeanParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| }; | |||
| class OnnxPowParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxPowParser() : OnnxNodeParser("Power") {} | |||
| @@ -38,7 +38,6 @@ STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||
| return RET_NULL_PTR; | |||
| } | |||
| // use channel dim as axis | |||
| attr->axis = {1}; | |||
| op->primitive->value.type = schema::PrimitiveType_BiasAdd; | |||
| @@ -52,7 +52,7 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons | |||
| attr->value.push_back(static_cast<float>(onnx_node_attr.i())); | |||
| break; | |||
| case onnx::AttributeProto_AttributeType_TENSOR: { | |||
| auto tensor = onnx_node_attr.t(); | |||
| const auto &tensor = onnx_node_attr.t(); | |||
| auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| @@ -67,7 +67,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| // set default params | |||
| attr->strideH = 1; | |||
| attr->strideW = 1; | |||
| attr->dilateH = 1; | |||
| @@ -75,6 +75,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| attr->group = 1; | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| attr->format = schema::Format::Format_NCHW; | |||
| // set opdef each attr params | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| if (onnx_node_attr.name() == "group") { | |||
| @@ -157,8 +158,10 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(), | |||
| [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); | |||
| if (iter != (*nodeIter).attribute().end()) { | |||
| MS_ASSERT(iter->ints().begin() != nullptr); | |||
| MS_ASSERT(iter->ints().end() != nullptr); | |||
| if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) { | |||
| MS_LOG(ERROR) << "dims insert failed"; | |||
| return RET_ERROR; | |||
| } | |||
| dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | |||
| } | |||
| attr->channelOut = dims[0]; | |||
| @@ -28,7 +28,7 @@ class OnnxConverter : public Converter { | |||
| public: | |||
| OnnxConverter(); | |||
| ~OnnxConverter() = default; | |||
| ~OnnxConverter() override = default; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -71,14 +71,12 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| return RET_NULL_PTR; | |||
| } | |||
| // set default params | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| attr->group = 1; | |||
| attr->strideW = 1; | |||
| attr->strideH = 1; | |||
| attr->dilateW = 1; | |||
| attr->dilateH = 1; | |||
| // set opdef each attr params | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| if (onnx_node_attr.name() == "group") { | |||
| attr->group = static_cast<int32_t>(onnx_node_attr.i()); | |||
| @@ -144,10 +142,14 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } | |||
| std::vector<int> weight_shape; | |||
| auto size = (*nodeIter).dims_size(); | |||
| weight_shape.reserve(size); | |||
| for (int i = 0; i < size; ++i) { | |||
| weight_shape.emplace_back((*nodeIter).dims(i)); | |||
| } | |||
| MS_ASSERT(weight_shape.size() == 4); | |||
| if (weight_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size(); | |||
| return RET_ERROR; | |||
| } | |||
| attr->channelIn = weight_shape[0]; | |||
| attr->channelOut = weight_shape[1] * attr->group; | |||
| @@ -31,7 +31,7 @@ class OnnxDeConvParser : public OnnxNodeParser { | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| private: | |||
| bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op); | |||
| static bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -48,11 +48,10 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; | |||
| return RET_ERROR; | |||
| } | |||
| const int64_t *dataPtr = nullptr; | |||
| for (const auto &attrPower : nodeIter->attribute()) { | |||
| if (attrPower.name() == "value") { | |||
| const auto &t = attrPower.t(); | |||
| dataPtr = reinterpret_cast<const int64_t *>(t.raw_data().data()); | |||
| auto *dataPtr = reinterpret_cast<const int64_t *>(t.raw_data().data()); | |||
| for (int i = 0; i < t.dims(0); ++i) { | |||
| dst_shape.emplace_back(dataPtr[i]); | |||
| } | |||
| @@ -25,7 +25,7 @@ namespace lite { | |||
| class OnnxLpNormParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxLpNormParser() : OnnxNodeParser("LpNorm") {} | |||
| ~OnnxLpNormParser() = default; | |||
| ~OnnxLpNormParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| }; | |||
| @@ -39,7 +39,7 @@ STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| if (onnx_node_attr.name() == "direction") { | |||
| auto direction = onnx_node_attr.s(); | |||
| const auto &direction = onnx_node_attr.s(); | |||
| attr->bidirection = direction == "bidirectional"; | |||
| } | |||
| } | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "google/protobuf/message.h" | |||
| #include "proto/onnx.pb.h" | |||
| @@ -29,13 +30,13 @@ namespace mindspore { | |||
| namespace lite { | |||
| class OnnxNodeParser { | |||
| public: | |||
| explicit OnnxNodeParser(const std::string nodeName) : name(nodeName) {} | |||
| explicit OnnxNodeParser(std::string nodeName) : name(std::move(nodeName)) {} | |||
| virtual ~OnnxNodeParser() = default; | |||
| virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0; | |||
| STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type); | |||
| static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type); | |||
| static STATUS set_opset_version(int version) { | |||
| opset_version_ = version; | |||
| @@ -44,9 +45,9 @@ class OnnxNodeParser { | |||
| static int opset_version() { return opset_version_; } | |||
| protected: | |||
| schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); | |||
| static schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); | |||
| void Split(const std::string &src_str, std::vector<std::string> *dst_str, const std::string &chr); | |||
| static void Split(const std::string &src_str, std::vector<std::string> *dst_str, const std::string &chr); | |||
| const std::string name; | |||
| @@ -40,13 +40,6 @@ OnnxNodeParser *OnnxNodeParserRegistry::GetNodeParser(const std::string &name) { | |||
| if (it != parsers.end()) { | |||
| return it->second; | |||
| } | |||
| /* should not support vague name, otherwise may get wrong parser. ex. PRelu and Relu | |||
| for (auto const &i : parsers) { | |||
| if (name.find(i.first) != std::string::npos) { | |||
| return i.second; | |||
| } | |||
| } | |||
| */ | |||
| return nullptr; | |||
| } | |||
| } // namespace lite | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_pool_parser.h" | |||
| #include <memory> | |||
| namespace mindspore { | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_relu_parser.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "securec/include/securec.h" | |||
| @@ -63,7 +62,6 @@ STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx PReluParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -113,7 +111,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||
| } | |||
| OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); | |||
| OnnxNodeRegistrar g_onnxLeakyReluParser("LeakyRelu", new OnnxLeakeyReluParser()); | |||
| OnnxNodeRegistrar g_onnxLeakyReluParser("LeakyRelu", new OnnxReluParser()); | |||
| OnnxNodeRegistrar g_onnxPReluParser("PRelu", new OnnxPReluParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -30,12 +30,6 @@ class OnnxReluParser : public OnnxNodeParser { | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| }; | |||
| class OnnxLeakeyReluParser : public OnnxReluParser { | |||
| public: | |||
| OnnxLeakeyReluParser() : OnnxReluParser() {} | |||
| ~OnnxLeakeyReluParser() override = default; | |||
| }; | |||
| class OnnxPReluParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxPReluParser() : OnnxNodeParser("Prelu") {} | |||
| @@ -43,7 +43,6 @@ STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| attr->k = static_cast<int32_t>(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| // attr->sorted; | |||
| op->primitive->value.type = schema::PrimitiveType_TopK; | |||
| op->primitive->value.value = attr.release(); | |||
| @@ -41,13 +41,7 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx | |||
| attr->conjugate = false; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axes") { | |||
| attr->perm.resize(onnx_node_attr.ints_size()); | |||
| for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { | |||
| attr->perm[i] = onnx_node_attr.ints(i); | |||
| } | |||
| } | |||
| if (attribute_name == "perm") { | |||
| if (attribute_name == "axes" || attribute_name == "perm") { | |||
| attr->perm.resize(onnx_node_attr.ints_size()); | |||
| for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { | |||
| attr->perm[i] = onnx_node_attr.ints(i); | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_upsample_parser.h" | |||
| #include <memory> | |||
| namespace mindspore { | |||
| @@ -18,7 +18,6 @@ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -26,6 +25,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -71,6 +73,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| } | |||
| attr->alpha = tflite_attr->alpha; | |||
| attr->type = schema::ActivationType_LEAKY_RELU; | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||
| @@ -81,12 +86,12 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); | |||
| TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); | |||
| TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); | |||
| TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteSwishParser()); | |||
| TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); | |||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); | |||
| TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); | |||
| TfliteNodeRegister g_tfliteReluParser("Relu", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_tfliteRelu6Parser("Relu6", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_tfliteTanhParser("Tanh", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_tfliteSwishParser("Swish", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_tfliteHardSwishParser("HardSwish", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_tfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -34,41 +34,6 @@ class TfliteActivationParser : public TfliteNodeParser { | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| }; | |||
| class TfliteReluParser : public TfliteActivationParser { | |||
| public: | |||
| TfliteReluParser() : TfliteActivationParser() {} | |||
| }; | |||
| class TfliteRelu6Parser : public TfliteActivationParser { | |||
| public: | |||
| TfliteRelu6Parser() : TfliteActivationParser() {} | |||
| }; | |||
| class TfliteTanhParser : public TfliteActivationParser { | |||
| public: | |||
| TfliteTanhParser() : TfliteActivationParser() {} | |||
| }; | |||
| class TfliteLogisticParser : public TfliteActivationParser { | |||
| public: | |||
| TfliteLogisticParser() : TfliteActivationParser() {} | |||
| }; | |||
| class TfliteSwishParser : public TfliteActivationParser { | |||
| public: | |||
| TfliteSwishParser() : TfliteActivationParser() {} | |||
| }; | |||
| class TfliteHardSwishParser : public TfliteActivationParser { | |||
| public: | |||
| TfliteHardSwishParser() : TfliteActivationParser() {} | |||
| }; | |||
| class TfliteLeakyReluParser : public TfliteActivationParser { | |||
| public: | |||
| TfliteLeakyReluParser() : TfliteActivationParser() {} | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -18,7 +18,6 @@ | |||
| #include "tools/converter/parser/tflite/tflite_addn_parser.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -26,6 +25,9 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteAddNParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -43,11 +45,12 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||
| } | |||
| attr->N = tflite_subgraph->tensors.size() - 1; | |||
| op->primitive->value.type = schema::PrimitiveType_AddN; | |||
| op->primitive->value.value = attr.release(); | |||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| @@ -25,6 +25,9 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -48,7 +51,12 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||
| // get axis attr | |||
| auto axis_idx = tflite_op->inputs[1]; | |||
| auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; | |||
| auto axis_tensor = tflite_subgraph->tensors[axis_idx].get(); | |||
| if (axis_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "axis_tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto buffer_idx = axis_tensor->buffer; | |||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | |||
| if (buf_data == nullptr) { | |||
| MS_LOG(ERROR) << "the buf data is null"; | |||
| @@ -69,6 +77,6 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_TfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); | |||
| TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -25,6 +25,9 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteArgminParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -48,7 +51,12 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||
| // get axis attr | |||
| auto axis_idx = tflite_op->inputs[1]; | |||
| auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; | |||
| auto axis_tensor = tflite_subgraph->tensors[axis_idx].get(); | |||
| if (axis_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "axis_tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto buffer_idx = axis_tensor->buffer; | |||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | |||
| if (buf_data == nullptr) { | |||
| MS_LOG(ERROR) << "the buf data is null"; | |||
| @@ -69,6 +77,6 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_TfliteArgminParser("Argmin", new TfliteArgminParser()); | |||
| TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -18,7 +18,6 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -26,6 +25,9 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -165,11 +167,14 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Minimum; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| // set input | |||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| @@ -179,6 +184,9 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -303,6 +311,9 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Neg; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| @@ -314,6 +325,9 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -381,45 +395,48 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_LessEqual; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); | |||
| TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteSubParser()); | |||
| TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteMulParser()); | |||
| TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDivParser()); | |||
| TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteFloorDivParser()); | |||
| TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteFloorModParser()); | |||
| TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteRealDivParser()); | |||
| TfliteNodeRegister g_TflitePowParser("Pow", new TflitePowParser()); | |||
| TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteSquaredDifferenceParser()); | |||
| TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteMaximumParser()); | |||
| TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteMinimumParser()); | |||
| TfliteNodeRegister g_tfliteAddParser("Add", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteMulParser("Mul", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteDivParser("Div", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tflitePowParser("Pow", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteAbsParser()); | |||
| TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteExpParser()); | |||
| TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSqrtParser()); | |||
| TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteRsqrtParser()); | |||
| TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSquareParser()); | |||
| TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSinParser()); | |||
| TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteCosParser()); | |||
| TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); | |||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); | |||
| TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); | |||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); | |||
| TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteNegParser()); | |||
| TfliteNodeRegister g_tfliteAbsParser("Abs", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteExpParser("Exp", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteSquareParser("Square", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteSinParser("Sin", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteCosParser("Cos", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteLogParser("Log", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteCeilParser("Ceil", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); | |||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); | |||
| TfliteNodeRegister g_tfliteGreaterEParser("Greater", new TfliteGreaterParser()); | |||
| TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteGreaterEqualParser()); | |||
| TfliteNodeRegister g_tfliteLessParser("Less", new TfliteLessParser()); | |||
| TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteLessEqualParser()); | |||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteGreaterEParser("Greater", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteLessParser("Less", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteCompareOpParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -34,61 +34,6 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| }; | |||
| class TfliteAddParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteAddParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteSubParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteSubParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteMulParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteMulParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteDivParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteDivParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteFloorDivParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteFloorDivParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteFloorModParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteFloorModParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteSquaredDifferenceParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteSquaredDifferenceParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteRealDivParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteRealDivParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TflitePowParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TflitePowParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteMaximumParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteMaximumParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteMinimumParser : public TfliteDoubleInputOpParser { | |||
| public: | |||
| TfliteMinimumParser() : TfliteDoubleInputOpParser() {} | |||
| }; | |||
| class TfliteSingleInputOpParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | |||
| @@ -98,66 +43,6 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| }; | |||
| class TfliteAbsParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteAbsParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteExpParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteExpParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteSqrtParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteSqrtParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteSquareParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteSquareParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteSinParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteSinParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteCosParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteCosParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteRsqrtParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteRsqrtParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteLogParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteLogParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteRoundParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteRoundParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteCeilParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteCeilParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteFloorParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteFloorParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteNegParser : public TfliteSingleInputOpParser { | |||
| public: | |||
| TfliteNegParser() : TfliteSingleInputOpParser() {} | |||
| }; | |||
| class TfliteCompareOpParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | |||
| @@ -166,36 +51,6 @@ class TfliteCompareOpParser : public TfliteNodeParser { | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| }; | |||
| class TfliteEqualParser : public TfliteCompareOpParser { | |||
| public: | |||
| TfliteEqualParser() : TfliteCompareOpParser() {} | |||
| }; | |||
| class TfliteNotEqualParser : public TfliteCompareOpParser { | |||
| public: | |||
| TfliteNotEqualParser() : TfliteCompareOpParser() {} | |||
| }; | |||
| class TfliteGreaterParser : public TfliteCompareOpParser { | |||
| public: | |||
| TfliteGreaterParser() : TfliteCompareOpParser() {} | |||
| }; | |||
| class TfliteGreaterEqualParser : public TfliteCompareOpParser { | |||
| public: | |||
| TfliteGreaterEqualParser() : TfliteCompareOpParser() {} | |||
| }; | |||
| class TfliteLessParser : public TfliteCompareOpParser { | |||
| public: | |||
| TfliteLessParser() : TfliteCompareOpParser() {} | |||
| }; | |||
| class TfliteLessEqualParser : public TfliteCompareOpParser { | |||
| public: | |||
| TfliteLessEqualParser() : TfliteCompareOpParser() {} | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -27,6 +27,9 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| @@ -44,6 +47,9 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; | |||
| } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| std::unique_ptr<schema::BatchToSpaceT> attr = std::make_unique<schema::BatchToSpaceT>(); | |||
| @@ -70,6 +76,6 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| } | |||
| TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); | |||
| TfliteNodeRegister g_TfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceNDParser()); | |||
| TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -33,11 +33,6 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| }; | |||
| class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | |||
| public: | |||
| TfliteBatchToSpaceNDParser() : TfliteBatchToSpaceParser() {} | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||