| @@ -106,7 +106,7 @@ if (BUILD_CONVERTER) | |||||
| include_directories(${TOP_DIR}/third_party/protobuf/build/include) | include_directories(${TOP_DIR}/third_party/protobuf/build/include) | ||||
| link_directories(${TOP_DIR}/third_party/protobuf/build/lib) | link_directories(${TOP_DIR}/third_party/protobuf/build/lib) | ||||
| add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) | ||||
| add_subdirectory(src/common/anf_exporter) | |||||
| add_subdirectory(src/common/anf_importer) | |||||
| endif() | endif() | ||||
| if (BUILD_DEVICE) | if (BUILD_DEVICE) | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "base/core_ops.h" | #include "base/core_ops.h" | ||||
| #include "mindspore/core/ir/primitive.h" | #include "mindspore/core/ir/primitive.h" | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| // #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/primitive_t_value.h" | #include "src/ir/primitive_t_value.h" | ||||
| #include "src/ir/tensor.h" | #include "src/ir/tensor.h" | ||||
| #include "src/param_value_lite.h" | #include "src/param_value_lite.h" | ||||
| @@ -148,27 +148,27 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||||
| node->name = cnode->fullname_with_scope(); | node->name = cnode->fullname_with_scope(); | ||||
| node->nodeType = schema::NodeType_CNode; | node->nodeType = schema::NodeType_CNode; | ||||
| // populate primitive | // populate primitive | ||||
| if (primitive != nullptr) { | |||||
| primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| std::string opType = primitive->name(); | |||||
| auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||||
| if (nodeParser == nullptr) { | |||||
| MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<schema::TensorT *> outputs; | |||||
| if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) { | |||||
| auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract()); | |||||
| outputs.resize(abstract_cnode->size()); | |||||
| } | |||||
| nodeParser->Parse(cnode, node.get(), &outputs); | |||||
| SetOpInputNode(cnode, metaGraphT.get(), node.get()); | |||||
| SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); | |||||
| metaGraphT->nodes.emplace_back(std::move(node)); | |||||
| continue; | |||||
| } | |||||
| // if (primitive != nullptr) { | |||||
| // primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||||
| // MS_ASSERT(primitive != nullptr); | |||||
| // std::string opType = primitive->name(); | |||||
| // auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||||
| // if (nodeParser == nullptr) { | |||||
| // MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | |||||
| // return nullptr; | |||||
| // } | |||||
| // std::vector<schema::TensorT *> outputs; | |||||
| // if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) { | |||||
| // auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract()); | |||||
| // outputs.resize(abstract_cnode->size()); | |||||
| // } | |||||
| // | |||||
| // nodeParser->Parse(cnode, node.get(), &outputs); | |||||
| // SetOpInputNode(cnode, metaGraphT.get(), node.get()); | |||||
| // SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); | |||||
| // metaGraphT->nodes.emplace_back(std::move(node)); | |||||
| // continue; | |||||
| // } | |||||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | ||||
| if (primitiveT_value == nullptr) { | if (primitiveT_value == nullptr) { | ||||
| MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | ||||
| @@ -1,7 +1,7 @@ | |||||
| file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | ||||
| *.cc | *.cc | ||||
| ) | ) | ||||
| add_library(anf_exporter_mid OBJECT | |||||
| list(REMOVE_ITEM ANF_SRC_LIST import_from_meta_graph.cc) | |||||
| add_library(anf_importer_mid OBJECT | |||||
| ${ANF_SRC_LIST} | ${ANF_SRC_LIST} | ||||
| ) | ) | ||||
| @@ -13,33 +13,33 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_activation_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| auto p = GetCNodePrimitive(cnodePtr); | |||||
| int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::ActivationT>(); | auto attr = std::make_unique<schema::ActivationT>(); | ||||
| if (p->name() == "ReLU") { | |||||
| if (prim->name() == "ReLU") { | |||||
| attr->type = schema::ActivationType_RELU; | attr->type = schema::ActivationType_RELU; | ||||
| } else if (p->name() == "Sigmoid") { | |||||
| } else if (prim->name() == "Sigmoid") { | |||||
| attr->type = schema::ActivationType_SIGMOID; | attr->type = schema::ActivationType_SIGMOID; | ||||
| } else if (p->name() == "ReLU6") { | |||||
| } else if (prim->name() == "ReLU6") { | |||||
| attr->type = schema::ActivationType_RELU6; | attr->type = schema::ActivationType_RELU6; | ||||
| } | } | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Activation; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Activation; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfReLUParser("ReLU", new AnfActivationPopulater()); | |||||
| AnfNodePopulaterRegistrar anfReLU6Parser("ReLU6", new AnfActivationPopulater()); | |||||
| AnfNodePopulaterRegistrar anfSigmoidParser("Sigmoid", new AnfActivationPopulater()); | |||||
| AnfNodePopulaterRegistrar anfReLUPopulater("ReLU", new AnfActivationPopulater()); | |||||
| AnfNodePopulaterRegistrar anfReLU6Populater("ReLU6", new AnfActivationPopulater()); | |||||
| AnfNodePopulaterRegistrar anfSigmoidPopulater("Sigmoid", new AnfActivationPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -16,14 +16,15 @@ | |||||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | #define MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfActivationPopulater : public AnfNodePopulater { | class AnfActivationPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfActivationPopulater() = default; | AnfActivationPopulater() = default; | ||||
| ~AnfActivationPopulater() override = default; | ~AnfActivationPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,25 +13,24 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_batchnorm_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfBatchnormParser::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| auto p = GetCNodePrimitive(cnodePtr); | |||||
| int AnfBatchnormPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::FusedBatchNormT>(); | auto attr = std::make_unique<schema::FusedBatchNormT>(); | ||||
| attr->epsilon = GetValue<float>(p->GetAttr("epsilon")); | |||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; | |||||
| node->primitive->value.value = attr.release(); | |||||
| attr->epsilon = GetValue<float>(prim->GetAttr("epsilon")); | |||||
| primitive->value.type = schema::PrimitiveType_FusedBatchNorm; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfBatchnormParser("BatchNorm", new AnfBatchnormParser()); | |||||
| AnfNodePopulaterRegistrar anfBatchnormPopulater("BatchNorm", new AnfBatchnormPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | ||||
| #define MINDSPORE_ANF_BATCHNORM_PARSER_H | #define MINDSPORE_ANF_BATCHNORM_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfBatchnormParser : public AnfNodePopulater { | |||||
| class AnfBatchnormPopulater : public AnfNodePopulater { | |||||
| public: | public: | ||||
| AnfBatchnormParser() = default; | |||||
| ~AnfBatchnormParser() override = default; | |||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| AnfBatchnormPopulater() = default; | |||||
| ~AnfBatchnormPopulater() override = default; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,25 +13,25 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_biasadd_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfBiasAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| int AnfBiasAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::BiasAddT>(); | auto attr = std::make_unique<schema::BiasAddT>(); | ||||
| attr->axis = {0}; | attr->axis = {0}; | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_BiasAdd; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_BiasAdd; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfBiasAddParser("BiasAdd", new AnfBiasAddPopulater()); | |||||
| AnfNodePopulaterRegistrar anfBiasAddPopulater("BiasAdd", new AnfBiasAddPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_BIASADD_PARSER_H | #ifndef MINDSPORE_ANF_BIASADD_PARSER_H | ||||
| #define MINDSPORE_ANF_BIASADD_PARSER_H | #define MINDSPORE_ANF_BIASADD_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfBiasAddPopulater : public AnfNodePopulater { | class AnfBiasAddPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfBiasAddPopulater() = default; | AnfBiasAddPopulater() = default; | ||||
| ~AnfBiasAddPopulater() override = default; | ~AnfBiasAddPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -16,30 +16,27 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_concat_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_concat_populater.h" | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfConcatPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| auto p = GetCNodePrimitive(cnodePtr); | |||||
| int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::ConcatT>(); | auto attr = std::make_unique<schema::ConcatT>(); | ||||
| auto prim_axis = GetValue<int>(p->GetAttr("axis")); | |||||
| auto prim_axis = GetValue<int>(prim->GetAttr("axis")); | |||||
| attr->axis = prim_axis; | attr->axis = prim_axis; | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Concat; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Concat; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfConcatParser("Concat", new AnfConcatPopulater()); | |||||
| AnfNodePopulaterRegistrar anfConcatPopulater("Concat", new AnfConcatPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -18,14 +18,15 @@ | |||||
| #ifndef MINDSPORE_ANF_CONCAT_PARSER_H | #ifndef MINDSPORE_ANF_CONCAT_PARSER_H | ||||
| #define MINDSPORE_ANF_CONCAT_PARSER_H | #define MINDSPORE_ANF_CONCAT_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfConcatPopulater : public AnfNodePopulater { | class AnfConcatPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfConcatPopulater() = default; | AnfConcatPopulater() = default; | ||||
| ~AnfConcatPopulater() override = default; | ~AnfConcatPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtrr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -16,23 +16,22 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_conv_populater.h" | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| auto p = GetCNodePrimitive(cnodePtr); | |||||
| int group = GetValue<int>(p->GetAttr("group")); | |||||
| int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| int group = GetValue<int>(prim->GetAttr("group")); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (group > 1) { | if (group > 1) { | ||||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | ||||
| auto format = GetValue<std::string>(p->GetAttr("data_format")); | |||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||||
| if (format == "NCHW") { | if (format == "NCHW") { | ||||
| attr->format = schema::Format_NCHW; | attr->format = schema::Format_NCHW; | ||||
| } else if (format == "NHWC") { | } else if (format == "NHWC") { | ||||
| @@ -40,25 +39,25 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||||
| } else { | } else { | ||||
| attr->format = schema::Format_NUM_OF_FORMAT; | attr->format = schema::Format_NUM_OF_FORMAT; | ||||
| } | } | ||||
| auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list")); | |||||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | attr->padUp = pad_list[0]; | ||||
| attr->padDown = pad_list[1]; | attr->padDown = pad_list[1]; | ||||
| attr->padLeft = pad_list[2]; | attr->padLeft = pad_list[2]; | ||||
| attr->padRight = pad_list[3]; | attr->padRight = pad_list[3]; | ||||
| auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); | |||||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | attr->dilateH = dilation[0]; | ||||
| attr->dilateW = dilation[1]; | attr->dilateW = dilation[1]; | ||||
| auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | attr->kernelH = kernel_size[0]; | ||||
| attr->kernelW = kernel_size[1]; | attr->kernelW = kernel_size[1]; | ||||
| auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); | |||||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | attr->strideH = stride[2]; | ||||
| attr->strideW = stride[3]; | attr->strideW = stride[3]; | ||||
| auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); | |||||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | if (pad_mode == "valid") { | ||||
| attr->padMode = schema::PadMode_VALID; | attr->padMode = schema::PadMode_VALID; | ||||
| } else if (pad_mode == "same") { | } else if (pad_mode == "same") { | ||||
| @@ -67,14 +66,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||||
| attr->padMode = schema::PadMode_NOTSET; | attr->padMode = schema::PadMode_NOTSET; | ||||
| } | } | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } else { | } else { | ||||
| auto attr = std::make_unique<schema::Conv2DT>(); | auto attr = std::make_unique<schema::Conv2DT>(); | ||||
| attr->group = group; | attr->group = group; | ||||
| auto format = GetValue<std::string>(p->GetAttr("data_format")); | |||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||||
| if (format == "NCHW") { | if (format == "NCHW") { | ||||
| attr->format = schema::Format_NCHW; | attr->format = schema::Format_NCHW; | ||||
| } else if (format == "NHWC") { | } else if (format == "NHWC") { | ||||
| @@ -82,27 +79,27 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||||
| } else { | } else { | ||||
| attr->format = schema::Format_NUM_OF_FORMAT; | attr->format = schema::Format_NUM_OF_FORMAT; | ||||
| } | } | ||||
| auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list")); | |||||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | attr->padUp = pad_list[0]; | ||||
| attr->padDown = pad_list[1]; | attr->padDown = pad_list[1]; | ||||
| attr->padLeft = pad_list[2]; | attr->padLeft = pad_list[2]; | ||||
| attr->padRight = pad_list[3]; | attr->padRight = pad_list[3]; | ||||
| auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); | |||||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | attr->dilateH = dilation[0]; | ||||
| attr->dilateW = dilation[1]; | attr->dilateW = dilation[1]; | ||||
| auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | attr->kernelH = kernel_size[0]; | ||||
| attr->kernelW = kernel_size[1]; | attr->kernelW = kernel_size[1]; | ||||
| auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); | |||||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | attr->strideH = stride[2]; | ||||
| attr->strideW = stride[3]; | attr->strideW = stride[3]; | ||||
| attr->channelOut = GetValue<int>(p->GetAttr("out_channel")); | |||||
| attr->channelOut = GetValue<int>(prim->GetAttr("out_channel")); | |||||
| auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); | |||||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | if (pad_mode == "valid") { | ||||
| attr->padMode = schema::PadMode_VALID; | attr->padMode = schema::PadMode_VALID; | ||||
| } else if (pad_mode == "same") { | } else if (pad_mode == "same") { | ||||
| @@ -110,12 +107,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||||
| } else { | } else { | ||||
| attr->padMode = schema::PadMode_NOTSET; | attr->padMode = schema::PadMode_NOTSET; | ||||
| } | } | ||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | } | ||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfConvParser("Conv2D", new AnfConvPopulater()); | |||||
| AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -18,14 +18,15 @@ | |||||
| #ifndef MINDSPORE_ANF_CONV_PARSER_H | #ifndef MINDSPORE_ANF_CONV_PARSER_H | ||||
| #define MINDSPORE_ANF_CONV_PARSER_H | #define MINDSPORE_ANF_CONV_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfConvPopulater : public AnfNodePopulater { | class AnfConvPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfConvPopulater() = default; | AnfConvPopulater() = default; | ||||
| ~AnfConvPopulater() override = default; | ~AnfConvPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,21 +13,21 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| auto p = GetCNodePrimitive(cnodePtr); | |||||
| int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | ||||
| auto format = GetValue<std::string>(p->GetAttr("data_format")); | |||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||||
| if (format == "NCHW") { | if (format == "NCHW") { | ||||
| attr->format = schema::Format_NCHW; | attr->format = schema::Format_NCHW; | ||||
| } else if (format == "NHWC") { | } else if (format == "NHWC") { | ||||
| @@ -35,25 +35,25 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP | |||||
| } else { | } else { | ||||
| attr->format = schema::Format_NUM_OF_FORMAT; | attr->format = schema::Format_NUM_OF_FORMAT; | ||||
| } | } | ||||
| auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pads")); | |||||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pads")); | |||||
| attr->padUp = pad_list[0]; | attr->padUp = pad_list[0]; | ||||
| attr->padDown = pad_list[1]; | attr->padDown = pad_list[1]; | ||||
| attr->padLeft = pad_list[2]; | attr->padLeft = pad_list[2]; | ||||
| attr->padRight = pad_list[3]; | attr->padRight = pad_list[3]; | ||||
| auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); | |||||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | attr->dilateH = dilation[0]; | ||||
| attr->dilateW = dilation[1]; | attr->dilateW = dilation[1]; | ||||
| auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | attr->kernelH = kernel_size[0]; | ||||
| attr->kernelW = kernel_size[1]; | attr->kernelW = kernel_size[1]; | ||||
| auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); | |||||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | attr->strideH = stride[2]; | ||||
| attr->strideW = stride[3]; | attr->strideW = stride[3]; | ||||
| auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); | |||||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | if (pad_mode == "valid") { | ||||
| attr->padMode = schema::PadMode_VALID; | attr->padMode = schema::PadMode_VALID; | ||||
| } else if (pad_mode == "same") { | } else if (pad_mode == "same") { | ||||
| @@ -62,11 +62,11 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP | |||||
| attr->padMode = schema::PadMode_NOTSET; | attr->padMode = schema::PadMode_NOTSET; | ||||
| } | } | ||||
| auto channel_multiplier = GetValue<int>(p->GetAttr("channel_multiplier")); | |||||
| auto channel_multiplier = GetValue<int>(prim->GetAttr("channel_multiplier")); | |||||
| attr->channelMultiplier = channel_multiplier; | attr->channelMultiplier = channel_multiplier; | ||||
| MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); | |||||
| auto inputNode = cnodePtr->input(kAnfPopulaterTwo); | |||||
| MS_ASSERT(inputs.size() == kAnfPopulaterThree); | |||||
| auto inputNode = inputs[kAnfPopulaterTwo]; | |||||
| MS_ASSERT(inputNode != nullptr); | MS_ASSERT(inputNode != nullptr); | ||||
| if (inputNode->isa<Parameter>()) { | if (inputNode->isa<Parameter>()) { | ||||
| auto paramNode = inputNode->cast<ParameterPtr>(); | auto paramNode = inputNode->cast<ParameterPtr>(); | ||||
| @@ -82,12 +82,12 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP | |||||
| } | } | ||||
| } | } | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfdepthwise2dParser("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); | |||||
| AnfNodePopulaterRegistrar anfdepthwise2dnativeParser("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); | |||||
| AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); | |||||
| AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | ||||
| #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfDepwiseconv2DPopulater : public AnfNodePopulater { | class AnfDepwiseconv2DPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfDepwiseconv2DPopulater() = default; | AnfDepwiseconv2DPopulater() = default; | ||||
| ~AnfDepwiseconv2DPopulater() override = default; | ~AnfDepwiseconv2DPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,23 +13,24 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_dequant_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfDequantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| int AnfDequantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::OnnxInt8DequantizeT>(); | auto attr = std::make_unique<schema::OnnxInt8DequantizeT>(); | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfDequantParser("Dequant", new AnfDequantPopulater()); | |||||
| AnfNodePopulaterRegistrar anfDequantPopulater("Dequant", new AnfDequantPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_DEQUANT_PARSER_H | #ifndef MINDSPORE_ANF_DEQUANT_PARSER_H | ||||
| #define MINDSPORE_ANF_DEQUANT_PARSER_H | #define MINDSPORE_ANF_DEQUANT_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfDequantPopulater : public AnfNodePopulater { | class AnfDequantPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfDequantPopulater() = default; | AnfDequantPopulater() = default; | ||||
| ~AnfDequantPopulater() override = default; | ~AnfDequantPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,23 +13,24 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_flatten_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfFlattenPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| int AnfFlattenPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::FlattenT>(); | auto attr = std::make_unique<schema::FlattenT>(); | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Flatten; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Flatten; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfFlattenParser("Flatten", new AnfFlattenPopulater()); | |||||
| AnfNodePopulaterRegistrar anfFlattenPopulater("Flatten", new AnfFlattenPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_FLATTEN_PARSER_H | #ifndef MINDSPORE_ANF_FLATTEN_PARSER_H | ||||
| #define MINDSPORE_ANF_FLATTEN_PARSER_H | #define MINDSPORE_ANF_FLATTEN_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfFlattenPopulater : public AnfNodePopulater { | class AnfFlattenPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfFlattenPopulater() = default; | AnfFlattenPopulater() = default; | ||||
| ~AnfFlattenPopulater() override = default; | ~AnfFlattenPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,26 +13,26 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_matmul_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| auto p = GetCNodePrimitive(cnodePtr); | |||||
| int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::MatMulT>(); | auto attr = std::make_unique<schema::MatMulT>(); | ||||
| attr->transposeA = GetValue<bool>(p->GetAttr("transpose_a")); | |||||
| attr->transposeB = GetValue<bool>(p->GetAttr("transpose_b")); | |||||
| attr->transposeA = GetValue<bool>(prim->GetAttr("transpose_a")); | |||||
| attr->transposeB = GetValue<bool>(prim->GetAttr("transpose_b")); | |||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_MatMul; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_MatMul; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfMatmulParser("MatMul", new AnfMatmulPopulater()); | |||||
| AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_MATMUL_PARSER_H | #ifndef MINDSPORE_ANF_MATMUL_PARSER_H | ||||
| #define MINDSPORE_ANF_MATMUL_PARSER_H | #define MINDSPORE_ANF_MATMUL_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfMatmulPopulater : public AnfNodePopulater { | class AnfMatmulPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfMatmulPopulater() = default; | AnfMatmulPopulater() = default; | ||||
| ~AnfMatmulPopulater() override = default; | ~AnfMatmulPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,23 +13,23 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_mul_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_mul_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| int AnfMulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::MulT>(); | auto attr = std::make_unique<schema::MulT>(); | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Mul; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Mul; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater()); | |||||
| AnfNodePopulaterRegistrar anfMulPopulater("Mul", new AnfMulPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | #define MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfMulPopulater : public AnfNodePopulater { | class AnfMulPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfMulPopulater() = default; | AnfMulPopulater() = default; | ||||
| ~AnfMulPopulater() override = default; | ~AnfMulPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -14,6 +14,6 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| namespace mindspore::lite {} // namespace mindspore::lite | namespace mindspore::lite {} // namespace mindspore::lite | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| constexpr int kAnfPopulaterOne = 1; | constexpr int kAnfPopulaterOne = 1; | ||||
| @@ -28,7 +29,9 @@ class AnfNodePopulater { | |||||
| public: | public: | ||||
| AnfNodePopulater() = default; | AnfNodePopulater() = default; | ||||
| virtual ~AnfNodePopulater() = default; | virtual ~AnfNodePopulater() = default; | ||||
| virtual int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) = 0; | |||||
| virtual int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) = 0; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -14,14 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include <string> | #include <string> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" | |||||
| #include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" | |||||
| #include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" | |||||
| #include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" | |||||
| #include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" | |||||
| #include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { | AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { | ||||
| @@ -29,13 +23,13 @@ AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { | |||||
| return &instance; | return &instance; | ||||
| } | } | ||||
| AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) { | AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) { | ||||
| if (parsers.find(name) == parsers.end()) { | |||||
| if (populaters.find(name) == populaters.end()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return parsers[name]; | |||||
| return populaters[name]; | |||||
| } | } | ||||
| void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *parser) { | |||||
| parsers[name] = parser; | |||||
| void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *populater) { | |||||
| populaters[name] = populater; | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -16,7 +16,7 @@ | |||||
| #ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | #ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | ||||
| #define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | #define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <string> | #include <string> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| @@ -26,16 +26,16 @@ class AnfNodePopulaterRegistry { | |||||
| virtual ~AnfNodePopulaterRegistry() = default; | virtual ~AnfNodePopulaterRegistry() = default; | ||||
| static AnfNodePopulaterRegistry *GetInstance(); | static AnfNodePopulaterRegistry *GetInstance(); | ||||
| AnfNodePopulater *GetNodePopulater(const std::string &name); | AnfNodePopulater *GetNodePopulater(const std::string &name); | ||||
| void SetNodePopulater(const std::string &name, AnfNodePopulater *parser); | |||||
| void SetNodePopulater(const std::string &name, AnfNodePopulater *populater); | |||||
| private: | private: | ||||
| std::unordered_map<std::string, AnfNodePopulater *> parsers; | |||||
| std::unordered_map<std::string, AnfNodePopulater *> populaters; | |||||
| }; | }; | ||||
| class AnfNodePopulaterRegistrar { | class AnfNodePopulaterRegistrar { | ||||
| public: | public: | ||||
| AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *parser) { | |||||
| AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, parser); | |||||
| AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *populater) { | |||||
| AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, populater); | |||||
| } | } | ||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,26 +13,26 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_pool_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| auto p = GetCNodePrimitive(cnodePtr); | |||||
| int AnfPoolPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::PoolingT>(); | auto attr = std::make_unique<schema::PoolingT>(); | ||||
| if (p->instance_name() == "MaxPool") { | |||||
| if (prim->instance_name() == "MaxPool") { | |||||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | attr->poolingMode = schema::PoolMode_MAX_POOLING; | ||||
| } else if (p->instance_name() == "MeanPool") { | |||||
| } else if (prim->instance_name() == "MeanPool") { | |||||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | attr->poolingMode = schema::PoolMode_MEAN_POOLING; | ||||
| } | } | ||||
| auto format = GetValue<std::string>(p->GetAttr("data_format")); | |||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||||
| if (format == "NCHW") { | if (format == "NCHW") { | ||||
| attr->format = schema::Format_NCHW; | attr->format = schema::Format_NCHW; | ||||
| } else if (format == "NHWC") { | } else if (format == "NHWC") { | ||||
| @@ -41,7 +41,7 @@ int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | attr->format = schema::Format_NUM_OF_FORMAT; | ||||
| } | } | ||||
| auto pad_mode = GetValue<std::string>(p->GetAttr("padding")); | |||||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("padding")); | |||||
| if (pad_mode == "VALID") { | if (pad_mode == "VALID") { | ||||
| attr->padMode = schema::PadMode_VALID; | attr->padMode = schema::PadMode_VALID; | ||||
| } else if (pad_mode == "SAME") { | } else if (pad_mode == "SAME") { | ||||
| @@ -50,19 +50,19 @@ int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||||
| attr->padMode = schema::PadMode_NOTSET; | attr->padMode = schema::PadMode_NOTSET; | ||||
| } | } | ||||
| auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("ksize")); | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("ksize")); | |||||
| attr->windowH = kernel_size[2]; | attr->windowH = kernel_size[2]; | ||||
| attr->windowW = kernel_size[3]; | attr->windowW = kernel_size[3]; | ||||
| auto stride = GetValue<std::vector<int>>(p->GetAttr("strides")); | |||||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("strides")); | |||||
| attr->strideH = stride[2]; | attr->strideH = stride[2]; | ||||
| attr->strideW = stride[3]; | attr->strideW = stride[3]; | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Pooling; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Pooling; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfMaxPoolParser("MaxPool", new AnfPoolPopulater()); | |||||
| AnfNodePopulaterRegistrar anfMaxPoolPopulater("MaxPool", new AnfPoolPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_POOL_PARSER_H | #ifndef MINDSPORE_ANF_POOL_PARSER_H | ||||
| #define MINDSPORE_ANF_POOL_PARSER_H | #define MINDSPORE_ANF_POOL_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfPoolPopulater : public AnfNodePopulater { | class AnfPoolPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfPoolPopulater() = default; | AnfPoolPopulater() = default; | ||||
| ~AnfPoolPopulater() override = default; | ~AnfPoolPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,23 +13,24 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_quant_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_quant_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfQuantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| int AnfQuantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::OnnxInt8QuantizeT>(); | auto attr = std::make_unique<schema::OnnxInt8QuantizeT>(); | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfQuantParser("Quant", new AnfQuantPopulater()); | |||||
| AnfNodePopulaterRegistrar anfQuantPopulater("Quant", new AnfQuantPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_QUANT_PARSER_H | #ifndef MINDSPORE_ANF_QUANT_PARSER_H | ||||
| #define MINDSPORE_ANF_QUANT_PARSER_H | #define MINDSPORE_ANF_QUANT_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfQuantPopulater : public AnfNodePopulater { | class AnfQuantPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfQuantPopulater() = default; | AnfQuantPopulater() = default; | ||||
| ~AnfQuantPopulater() override = default; | ~AnfQuantPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,10 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_reducemean_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_reducemean_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -25,15 +25,15 @@ namespace { | |||||
| constexpr int kReduceInputNum = 3; | constexpr int kReduceInputNum = 3; | ||||
| constexpr int kReduceInputIndex = 2; | constexpr int kReduceInputIndex = 2; | ||||
| } | } | ||||
| int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| auto p = GetCNodePrimitive(cnodePtr); | |||||
| int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::ReduceT>(); | auto attr = std::make_unique<schema::ReduceT>(); | ||||
| attr->mode = schema::ReduceMode_ReduceMean; | attr->mode = schema::ReduceMode_ReduceMean; | ||||
| attr->keepDims = GetValue<bool>(p->GetAttr("keep_dims")); | |||||
| if (cnodePtr->inputs().size() == kReduceInputNum) { | |||||
| auto inputNode = cnodePtr->input(kReduceInputIndex); | |||||
| attr->keepDims = GetValue<bool>(prim->GetAttr("keep_dims")); | |||||
| if (inputs.size() == kReduceInputNum) { | |||||
| auto inputNode = inputs[kReduceInputIndex]; | |||||
| MS_ASSERT(inputNode != nullptr); | MS_ASSERT(inputNode != nullptr); | ||||
| if (inputNode->isa<ValueNode>()) { | if (inputNode->isa<ValueNode>()) { | ||||
| auto valueNode = inputNode->cast<ValueNodePtr>(); | auto valueNode = inputNode->cast<ValueNodePtr>(); | ||||
| @@ -52,11 +52,11 @@ int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CN | |||||
| } | } | ||||
| } | } | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Reduce; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Reduce; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfReduceMeanParser("ReduceMean", new AnfReduceMeanPopulater()); | |||||
| AnfNodePopulaterRegistrar anfReduceMeanPopulater("ReduceMean", new AnfReduceMeanPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | #define MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfReduceMeanPopulater : public AnfNodePopulater { | class AnfReduceMeanPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfReduceMeanPopulater() = default; | AnfReduceMeanPopulater() = default; | ||||
| ~AnfReduceMeanPopulater() override = default; | ~AnfReduceMeanPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,19 +13,20 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_reshape_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| int AnfReshapePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::ReshapeT>(); | auto attr = std::make_unique<schema::ReshapeT>(); | ||||
| MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); | |||||
| auto inputNode = cnodePtr->input(kAnfPopulaterTwo); | |||||
| MS_ASSERT(inputs.size() == kAnfPopulaterThree); | |||||
| auto inputNode = inputs[kAnfPopulaterTwo]; | |||||
| if (inputNode->isa<ValueNode>()) { | if (inputNode->isa<ValueNode>()) { | ||||
| auto valueNode = inputNode->cast<ValueNodePtr>(); | auto valueNode = inputNode->cast<ValueNodePtr>(); | ||||
| MS_ASSERT(valueNode != nullptr); | MS_ASSERT(valueNode != nullptr); | ||||
| @@ -42,12 +43,12 @@ int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, sc | |||||
| } | } | ||||
| } | } | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Reshape; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Reshape; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfReshapeParser("Reshape", new AnfReshapePopulater()); | |||||
| AnfNodePopulaterRegistrar anfReshapePopulater("Reshape", new AnfReshapePopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -16,14 +16,15 @@ | |||||
| #ifndef MINDSPORE_ANF_RESHAPE_PARSER_H | #ifndef MINDSPORE_ANF_RESHAPE_PARSER_H | ||||
| #define MINDSPORE_ANF_RESHAPE_PARSER_H | #define MINDSPORE_ANF_RESHAPE_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfReshapePopulater : public AnfNodePopulater { | class AnfReshapePopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfReshapePopulater() = default; | AnfReshapePopulater() = default; | ||||
| ~AnfReshapePopulater() override = default; | ~AnfReshapePopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,22 +13,23 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_tensoradd_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfTensorAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| int AnfTensorAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::AddT>(); | auto attr = std::make_unique<schema::AddT>(); | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Add; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Add; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfTensorAddParser("TensorAdd", new AnfTensorAddPopulater()); | |||||
| AnfNodePopulaterRegistrar anfTensorAddPopulater("TensorAdd", new AnfTensorAddPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | #define MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfTensorAddPopulater : public AnfNodePopulater { | class AnfTensorAddPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfTensorAddPopulater() = default; | AnfTensorAddPopulater() = default; | ||||
| ~AnfTensorAddPopulater() override = default; | ~AnfTensorAddPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,21 +13,21 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_transpose_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_transpose_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| int AnfTransposePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::TransposeT>(); | auto attr = std::make_unique<schema::TransposeT>(); | ||||
| MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); | |||||
| auto inputNode = cnodePtr->input(kAnfPopulaterTwo); | |||||
| MS_ASSERT(inputs.size() == kAnfPopulaterThree); | |||||
| auto inputNode = inputs[kAnfPopulaterTwo]; | |||||
| if (inputNode->isa<ValueNode>()) { | if (inputNode->isa<ValueNode>()) { | ||||
| auto valNode = inputNode->cast<ValueNodePtr>(); | auto valNode = inputNode->cast<ValueNodePtr>(); | ||||
| MS_ASSERT(valNode != nullptr); | MS_ASSERT(valNode != nullptr); | ||||
| @@ -44,11 +44,11 @@ int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr, | |||||
| } | } | ||||
| } | } | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_Transpose; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Transpose; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfTransposeParser("Transpose", new AnfTransposePopulater()); | |||||
| AnfNodePopulaterRegistrar anfTransposePopulater("Transpose", new AnfTransposePopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H | #ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H | ||||
| #define MINDSPORE_ANF_TRANSPOSE_PARSER_H | #define MINDSPORE_ANF_TRANSPOSE_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfTransposePopulater : public AnfNodePopulater { | class AnfTransposePopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfTransposePopulater() = default; | AnfTransposePopulater() = default; | ||||
| ~AnfTransposePopulater() override = default; | ~AnfTransposePopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,22 +13,23 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int mindspore::lite::AnfTupleGetItemPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, | |||||
| std::vector<schema::TensorT *> *outputs) { | |||||
| int AnfTupleGetItemPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto attr = std::make_unique<schema::TupleGetItemT>(); | auto attr = std::make_unique<schema::TupleGetItemT>(); | ||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| node->primitive->value.type = schema::PrimitiveType_TupleGetItem; | |||||
| node->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_TupleGetItem; | |||||
| primitive->value.value = attr.release(); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfTupleGetItemParser("tuple_getitem", new AnfTupleGetItemPopulater()); | |||||
| AnfNodePopulaterRegistrar anfTupleGetItemPopulater("tuple_getitem", new AnfTupleGetItemPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,14 +15,15 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | ||||
| #define MINDSPORE_ANF_BATCHNORM_PARSER_H | #define MINDSPORE_ANF_BATCHNORM_PARSER_H | ||||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfTupleGetItemPopulater : public AnfNodePopulater { | class AnfTupleGetItemPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| AnfTupleGetItemPopulater() = default; | AnfTupleGetItemPopulater() = default; | ||||
| ~AnfTupleGetItemPopulater() override = default; | ~AnfTupleGetItemPopulater() override = default; | ||||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include "schema/inner/model_generated.h" | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | #include "google/protobuf/io/zero_copy_stream_impl.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -38,6 +39,7 @@ | |||||
| #include "tools/converter/parser/onnx/onnx.pb.h" | #include "tools/converter/parser/onnx/onnx.pb.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| using string = std::string; | using string = std::string; | ||||
| using int32 = int32_t; | using int32 = int32_t; | ||||
| @@ -997,10 +999,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } | } | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| inputs.clear(); | inputs.clear(); | ||||
| inputs.push_back(NewValueNode(prim)); | |||||
| for (int i = 0; i < node_proto.input_size(); ++i) { | for (int i = 0; i < node_proto.input_size(); ++i) { | ||||
| const std::string &input_name = node_proto.input(i); | const std::string &input_name = node_proto.input(i); | ||||
| if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | ||||
| @@ -1009,6 +1009,18 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out | |||||
| } | } | ||||
| inputs.push_back(anfnode_build_map_[input_name]); | inputs.push_back(anfnode_build_map_[input_name]); | ||||
| } | } | ||||
| std::string opType = prim->name(); | |||||
| auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||||
| if (node_parser == nullptr) { | |||||
| MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | |||||
| return nullptr; | |||||
| } | |||||
| auto primitiveT = std::make_unique<schema::PrimitiveT>(); | |||||
| // auto * primitiveTValue = new PrimitiveTValue(primitiveT.release()); | |||||
| std::shared_ptr<PrimitiveTValue> primitiveTValuePtr = std::make_shared<PrimitiveTValue>(primitiveT.release()); | |||||
| node_parser->Populate(prim, primitiveTValuePtr.get(), inputs); | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||||
| inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr)); | |||||
| CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); | CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); | ||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | MS_EXCEPTION_IF_NULL(cnode_ptr); | ||||
| if (node_type == "LayerNorm") { | if (node_type == "LayerNorm") { | ||||
| @@ -215,9 +215,9 @@ if(BUILD_CONVERTER) | |||||
| ${TEST_CASE_TFLITE_PARSERS_SRC} | ${TEST_CASE_TFLITE_PARSERS_SRC} | ||||
| ${TOP_DIR}/mindspore/core/utils/flags.cc | ${TOP_DIR}/mindspore/core/utils/flags.cc | ||||
| ${LITE_DIR}/tools/converter/optimizer.cc | ${LITE_DIR}/tools/converter/optimizer.cc | ||||
| ${LITE_DIR}/src/common/anf_importer/anf_importer.cc | |||||
| ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc | |||||
| ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc | |||||
| # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc | |||||
| # ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc | |||||
| # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc | |||||
| ${LITE_DIR}/tools/converter/anf_transform.cc | ${LITE_DIR}/tools/converter/anf_transform.cc | ||||
| ${LITE_DIR}/tools/converter/graphdef_transform.cc | ${LITE_DIR}/tools/converter/graphdef_transform.cc | ||||
| ${LITE_DIR}/tools/converter/converter_flags.cc | ${LITE_DIR}/tools/converter/converter_flags.cc | ||||
| @@ -345,7 +345,7 @@ if (BUILD_MINDDATA) | |||||
| endif() | endif() | ||||
| if (BUILD_CONVERTER) | if (BUILD_CONVERTER) | ||||
| target_link_libraries(lite-test | target_link_libraries(lite-test | ||||
| anf_exporter_mid | |||||
| anf_importer_mid | |||||
| tflite_parser_mid | tflite_parser_mid | ||||
| caffe_parser_mid | caffe_parser_mid | ||||
| node_mid | node_mid | ||||
| @@ -70,9 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/anf_importer.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_meta_graphT.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_protobuf.cc | |||||
| # ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc | ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc | ||||
| ../optimizer/common/node_pass_extends.cc | ../optimizer/common/node_pass_extends.cc | ||||
| @@ -99,7 +97,7 @@ add_executable(converter_lite | |||||
| target_link_libraries(converter_lite PRIVATE | target_link_libraries(converter_lite PRIVATE | ||||
| tflite_parser_mid | tflite_parser_mid | ||||
| caffe_parser_mid | caffe_parser_mid | ||||
| anf_exporter_mid | |||||
| anf_importer_mid | |||||
| node_mid | node_mid | ||||
| graph_pass_mid | graph_pass_mid | ||||
| fusion_mid | fusion_mid | ||||
| @@ -10,6 +10,7 @@ add_library(quantizer_mid OBJECT | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc | ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc | ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc | ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/common/anf_exporter/anf_exporter.cc | |||||
| ) | ) | ||||
| if(ENABLE_ASAN) | if(ENABLE_ASAN) | ||||