diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index aff588fc5e..36d2a4c432 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -106,7 +106,7 @@ if (BUILD_CONVERTER) include_directories(${TOP_DIR}/third_party/protobuf/build/include) link_directories(${TOP_DIR}/third_party/protobuf/build/lib) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) - add_subdirectory(src/common/anf_exporter) + add_subdirectory(src/common/anf_importer) endif() if (BUILD_DEVICE) diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index d565d16795..81cdbe9429 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -25,7 +25,7 @@ #include "abstract/abstract_value.h" #include "base/core_ops.h" #include "mindspore/core/ir/primitive.h" -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +// #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/ir/primitive_t_value.h" #include "src/ir/tensor.h" #include "src/param_value_lite.h" @@ -148,27 +148,27 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { node->name = cnode->fullname_with_scope(); node->nodeType = schema::NodeType_CNode; // populate primitive - if (primitive != nullptr) { - primitive = GetValueNode(cnode->input(0)); - MS_ASSERT(primitive != nullptr); - std::string opType = primitive->name(); - auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); - if (nodeParser == nullptr) { - MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; - return nullptr; - } - std::vector outputs; - if (utils::isa(cnode->abstract())) { - auto abstract_cnode = utils::cast(cnode->abstract()); - outputs.resize(abstract_cnode->size()); - } - - nodeParser->Parse(cnode, node.get(), &outputs); - SetOpInputNode(cnode, metaGraphT.get(), node.get()); - SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); - metaGraphT->nodes.emplace_back(std::move(node)); - continue; - } + // if (primitive != nullptr) { + // primitive = GetValueNode(cnode->input(0)); + // MS_ASSERT(primitive != nullptr); + // std::string opType = primitive->name(); + // auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); + // if (nodeParser == nullptr) { + // MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; + // return nullptr; + // } + // std::vector outputs; + // if (utils::isa(cnode->abstract())) { + // auto abstract_cnode = utils::cast(cnode->abstract()); + // outputs.resize(abstract_cnode->size()); + // } + // + // nodeParser->Parse(cnode, node.get(), &outputs); + // SetOpInputNode(cnode, metaGraphT.get(), node.get()); + // SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); + // metaGraphT->nodes.emplace_back(std::move(node)); + // continue; + // } auto primitiveT_value = GetValueNode>(cnode->input(0)); if (primitiveT_value == nullptr) { MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; diff --git a/mindspore/lite/src/common/anf_exporter/CMakeLists.txt b/mindspore/lite/src/common/anf_importer/CMakeLists.txt similarity index 57% rename from mindspore/lite/src/common/anf_exporter/CMakeLists.txt rename to mindspore/lite/src/common/anf_importer/CMakeLists.txt index 352f59947a..07111121bf 100644 --- a/mindspore/lite/src/common/anf_exporter/CMakeLists.txt +++ b/mindspore/lite/src/common/anf_importer/CMakeLists.txt @@ -1,7 +1,7 @@ file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cc ) -add_library(anf_exporter_mid OBJECT +list(REMOVE_ITEM ANF_SRC_LIST import_from_meta_graph.cc) +add_library(anf_importer_mid OBJECT ${ANF_SRC_LIST} ) - diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.cc similarity index 51% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.cc index bf6a66e57d..942b6b4311 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.cc @@ -13,33 +13,33 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" +#include "src/common/anf_importer/anf_populater/anf_activation_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { - auto p = GetCNodePrimitive(cnodePtr); +int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - if (p->name() == "ReLU") { + if (prim->name() == "ReLU") { attr->type = schema::ActivationType_RELU; - } else if (p->name() == "Sigmoid") { + } else if (prim->name() == "Sigmoid") { attr->type = schema::ActivationType_SIGMOID; - } else if (p->name() == "ReLU6") { + } else if (prim->name() == "ReLU6") { attr->type = schema::ActivationType_RELU6; } - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Activation; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Activation; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfReLUParser("ReLU", new AnfActivationPopulater()); -AnfNodePopulaterRegistrar anfReLU6Parser("ReLU6", new AnfActivationPopulater()); -AnfNodePopulaterRegistrar anfSigmoidParser("Sigmoid", new AnfActivationPopulater()); +AnfNodePopulaterRegistrar anfReLUPopulater("ReLU", new AnfActivationPopulater()); +AnfNodePopulaterRegistrar anfReLU6Populater("ReLU6", new AnfActivationPopulater()); +AnfNodePopulaterRegistrar anfSigmoidPopulater("Sigmoid", new AnfActivationPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.h index daa4add19c..ab43ff7d77 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.h @@ -16,14 +16,15 @@ #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfActivationPopulater : public AnfNodePopulater { public: AnfActivationPopulater() = default; ~AnfActivationPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.cc similarity index 54% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.cc index d8013aed14..f66c61562b 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.cc @@ -13,25 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h" +#include "src/common/anf_importer/anf_populater/anf_batchnorm_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfBatchnormParser::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { - auto p = GetCNodePrimitive(cnodePtr); +int AnfBatchnormPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - attr->epsilon = GetValue(p->GetAttr("epsilon")); - - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; - node->primitive->value.value = attr.release(); + attr->epsilon = GetValue(prim->GetAttr("epsilon")); + primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfBatchnormParser("BatchNorm", new AnfBatchnormParser()); +AnfNodePopulaterRegistrar anfBatchnormPopulater("BatchNorm", new AnfBatchnormPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.h similarity index 70% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.h index 1df83a87ac..84ce7e0567 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H #define MINDSPORE_ANF_BATCHNORM_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { -class AnfBatchnormParser : public AnfNodePopulater { +class AnfBatchnormPopulater : public AnfNodePopulater { public: - AnfBatchnormParser() = default; - ~AnfBatchnormParser() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + AnfBatchnormPopulater() = default; + ~AnfBatchnormPopulater() override = default; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.cc similarity index 57% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.cc index ad59e89936..44e2a35330 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.cc @@ -13,25 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" +#include "src/common/anf_importer/anf_populater/anf_biasadd_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfBiasAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { +int AnfBiasAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); attr->axis = {0}; - - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_BiasAdd; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_BiasAdd; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfBiasAddParser("BiasAdd", new AnfBiasAddPopulater()); +AnfNodePopulaterRegistrar anfBiasAddPopulater("BiasAdd", new AnfBiasAddPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.h index 6256e20567..3fbf17ee49 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_BIASADD_PARSER_H #define MINDSPORE_ANF_BIASADD_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfBiasAddPopulater : public AnfNodePopulater { public: AnfBiasAddPopulater() = default; ~AnfBiasAddPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.cc similarity index 58% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.cc index 1b4596205a..51c52eca68 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.cc @@ -16,30 +16,27 @@ * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_concat_populater.h" +#include "src/common/anf_importer/anf_populater/anf_concat_populater.h" #include #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfConcatPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { - auto p = GetCNodePrimitive(cnodePtr); +int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - - auto prim_axis = GetValue(p->GetAttr("axis")); + auto prim_axis = GetValue(prim->GetAttr("axis")); attr->axis = prim_axis; - - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Concat; - node->primitive->value.value = attr.release(); - + primitive->value.type = schema::PrimitiveType_Concat; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfConcatParser("Concat", new AnfConcatPopulater()); +AnfNodePopulaterRegistrar anfConcatPopulater("Concat", new AnfConcatPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.h similarity index 83% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.h index 9a9915dcb5..aa59219f92 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.h @@ -18,14 +18,15 @@ #ifndef MINDSPORE_ANF_CONCAT_PARSER_H #define MINDSPORE_ANF_CONCAT_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfConcatPopulater : public AnfNodePopulater { public: AnfConcatPopulater() = default; ~AnfConcatPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtrr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc similarity index 59% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc index c2dcee77ec..c662f832b8 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc @@ -16,23 +16,22 @@ * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" +#include "src/common/anf_importer/anf_populater/anf_conv_populater.h" #include #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { - auto p = GetCNodePrimitive(cnodePtr); - int group = GetValue(p->GetAttr("group")); - +int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + int group = GetValue(prim->GetAttr("group")); + auto primitive = std::make_unique(); if (group > 1) { auto attr = std::make_unique(); - auto format = GetValue(p->GetAttr("data_format")); + auto format = GetValue(prim->GetAttr("data_format")); if (format == "NCHW") { attr->format = schema::Format_NCHW; } else if (format == "NHWC") { @@ -40,25 +39,25 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem } else { attr->format = schema::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(p->GetAttr("pad_list")); + auto pad_list = GetValue>(prim->GetAttr("pad_list")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(p->GetAttr("dilation")); + auto dilation = GetValue>(prim->GetAttr("dilation")); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = GetValue>(p->GetAttr("kernel_size")); + auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = GetValue>(p->GetAttr("stride")); + auto stride = GetValue>(prim->GetAttr("stride")); attr->strideH = stride[2]; attr->strideW = stride[3]; - auto pad_mode = GetValue(p->GetAttr("pad_mode")); + auto pad_mode = GetValue(prim->GetAttr("pad_mode")); if (pad_mode == "valid") { attr->padMode = schema::PadMode_VALID; } else if (pad_mode == "same") { @@ -67,14 +66,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem attr->padMode = schema::PadMode_NOTSET; } - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + primitive->value.value = attr.release(); } else { auto attr = std::make_unique(); attr->group = group; - auto format = GetValue(p->GetAttr("data_format")); + auto format = GetValue(prim->GetAttr("data_format")); if (format == "NCHW") { attr->format = schema::Format_NCHW; } else if (format == "NHWC") { @@ -82,27 +79,27 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem } else { attr->format = schema::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(p->GetAttr("pad_list")); + auto pad_list = GetValue>(prim->GetAttr("pad_list")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(p->GetAttr("dilation")); + auto dilation = GetValue>(prim->GetAttr("dilation")); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = GetValue>(p->GetAttr("kernel_size")); + auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = GetValue>(p->GetAttr("stride")); + auto stride = GetValue>(prim->GetAttr("stride")); attr->strideH = stride[2]; attr->strideW = stride[3]; - attr->channelOut = GetValue(p->GetAttr("out_channel")); + attr->channelOut = GetValue(prim->GetAttr("out_channel")); - auto pad_mode = GetValue(p->GetAttr("pad_mode")); + auto pad_mode = GetValue(prim->GetAttr("pad_mode")); if (pad_mode == "valid") { attr->padMode = schema::PadMode_VALID; } else if (pad_mode == "same") { @@ -110,12 +107,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem } else { attr->padMode = schema::PadMode_NOTSET; } - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Conv2D; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Conv2D; + primitive->value.value = attr.release(); } + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } - -AnfNodePopulaterRegistrar anfConvParser("Conv2D", new AnfConvPopulater()); +AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h similarity index 83% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h index 88edda0951..5614f4c7cc 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h @@ -18,14 +18,15 @@ #ifndef MINDSPORE_ANF_CONV_PARSER_H #define MINDSPORE_ANF_CONV_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfConvPopulater : public AnfNodePopulater { public: AnfConvPopulater() = default; ~AnfConvPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc similarity index 62% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc index 8bb4c79771..b13bc6c822 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc @@ -13,21 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h" +#include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" #include #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { - auto p = GetCNodePrimitive(cnodePtr); +int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - auto format = GetValue(p->GetAttr("data_format")); + auto format = GetValue(prim->GetAttr("data_format")); if (format == "NCHW") { attr->format = schema::Format_NCHW; } else if (format == "NHWC") { @@ -35,25 +35,25 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP } else { attr->format = schema::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(p->GetAttr("pads")); + auto pad_list = GetValue>(prim->GetAttr("pads")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(p->GetAttr("dilation")); + auto dilation = GetValue>(prim->GetAttr("dilation")); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = GetValue>(p->GetAttr("kernel_size")); + auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = GetValue>(p->GetAttr("stride")); + auto stride = GetValue>(prim->GetAttr("stride")); attr->strideH = stride[2]; attr->strideW = stride[3]; - auto pad_mode = GetValue(p->GetAttr("pad_mode")); + auto pad_mode = GetValue(prim->GetAttr("pad_mode")); if (pad_mode == "valid") { attr->padMode = schema::PadMode_VALID; } else if (pad_mode == "same") { @@ -62,11 +62,11 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP attr->padMode = schema::PadMode_NOTSET; } - auto channel_multiplier = GetValue(p->GetAttr("channel_multiplier")); + auto channel_multiplier = GetValue(prim->GetAttr("channel_multiplier")); attr->channelMultiplier = channel_multiplier; - MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); - auto inputNode = cnodePtr->input(kAnfPopulaterTwo); + MS_ASSERT(inputs.size() == kAnfPopulaterThree); + auto inputNode = inputs[kAnfPopulaterTwo]; MS_ASSERT(inputNode != nullptr); if (inputNode->isa()) { auto paramNode = inputNode->cast(); @@ -82,12 +82,12 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP } } - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfdepthwise2dParser("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); -AnfNodePopulaterRegistrar anfdepthwise2dnativeParser("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); +AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); +AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h index de96776d6f..c9b63e710d 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfDepwiseconv2DPopulater : public AnfNodePopulater { public: AnfDepwiseconv2DPopulater() = default; ~AnfDepwiseconv2DPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.cc similarity index 57% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.cc index a08bf67d68..5df3a75c92 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.cc @@ -13,23 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h" +#include "src/common/anf_importer/anf_populater/anf_dequant_populater.h" #include #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfDequantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { +int AnfDequantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfDequantParser("Dequant", new AnfDequantPopulater()); +AnfNodePopulaterRegistrar anfDequantPopulater("Dequant", new AnfDequantPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.h index 12017ad60b..936468d85e 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_DEQUANT_PARSER_H #define MINDSPORE_ANF_DEQUANT_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfDequantPopulater : public AnfNodePopulater { public: AnfDequantPopulater() = default; ~AnfDequantPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.cc similarity index 56% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.cc index 8ba27f99a7..0e669345e6 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.cc @@ -13,23 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" +#include "src/common/anf_importer/anf_populater/anf_flatten_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfFlattenPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { +int AnfFlattenPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Flatten; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Flatten; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfFlattenParser("Flatten", new AnfFlattenPopulater()); +AnfNodePopulaterRegistrar anfFlattenPopulater("Flatten", new AnfFlattenPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.h index f2cf48ab02..8ec178b213 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_FLATTEN_PARSER_H #define MINDSPORE_ANF_FLATTEN_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfFlattenPopulater : public AnfNodePopulater { public: AnfFlattenPopulater() = default; ~AnfFlattenPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc similarity index 52% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc index 909ceec01a..7b5b5bb4d9 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc @@ -13,26 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" +#include "src/common/anf_importer/anf_populater/anf_matmul_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { - auto p = GetCNodePrimitive(cnodePtr); +int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - attr->transposeA = GetValue(p->GetAttr("transpose_a")); - attr->transposeB = GetValue(p->GetAttr("transpose_b")); + attr->transposeA = GetValue(prim->GetAttr("transpose_a")); + attr->transposeB = GetValue(prim->GetAttr("transpose_b")); - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_MatMul; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_MatMul; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfMatmulParser("MatMul", new AnfMatmulPopulater()); +AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h index 752e8eff31..651b41c9d7 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_MATMUL_PARSER_H #define MINDSPORE_ANF_MATMUL_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfMatmulPopulater : public AnfNodePopulater { public: AnfMatmulPopulater() = default; ~AnfMatmulPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.cc similarity index 57% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.cc index 4f5c3beec8..7edf1a9328 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.cc @@ -13,23 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_mul_populater.h" +#include "src/common/anf_importer/anf_populater/anf_mul_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { +int AnfMulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Mul; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Mul; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater()); +AnfNodePopulaterRegistrar anfMulPopulater("Mul", new AnfMulPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.h index 87f526cf7a..2761300d46 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfMulPopulater : public AnfNodePopulater { public: AnfMulPopulater() = default; ~AnfMulPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.cc similarity index 91% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.cc index 4045e0e043..9bec531c00 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.cc @@ -14,6 +14,6 @@ * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" namespace mindspore::lite {} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.h similarity index 84% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.h index 3d9accb75e..6270a1ea18 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.h @@ -19,6 +19,7 @@ #include #include "ir/anf.h" +#include "src/ir/primitive_t_value.h" #include "schema/inner/model_generated.h" namespace mindspore::lite { constexpr int kAnfPopulaterOne = 1; @@ -28,7 +29,9 @@ class AnfNodePopulater { public: AnfNodePopulater() = default; virtual ~AnfNodePopulater() = default; - virtual int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) = 0; + + virtual int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) = 0; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.cc similarity index 62% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.cc index 1ac99bd3f1..1d8d36bf10 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.cc @@ -14,14 +14,8 @@ * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include -#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" -#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" -#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" -#include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" -#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" -#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" namespace mindspore { namespace lite { AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { @@ -29,13 +23,13 @@ AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { return &instance; } AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) { - if (parsers.find(name) == parsers.end()) { + if (populaters.find(name) == populaters.end()) { return nullptr; } - return parsers[name]; + return populaters[name]; } -void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *parser) { - parsers[name] = parser; +void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *populater) { + populaters[name] = populater; } } // namespace lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.h similarity index 88% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.h index 321d4b5fb3..0d88eec3b1 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H #define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include #include namespace mindspore::lite { @@ -26,16 +26,16 @@ class AnfNodePopulaterRegistry { virtual ~AnfNodePopulaterRegistry() = default; static AnfNodePopulaterRegistry *GetInstance(); AnfNodePopulater *GetNodePopulater(const std::string &name); - void SetNodePopulater(const std::string &name, AnfNodePopulater *parser); + void SetNodePopulater(const std::string &name, AnfNodePopulater *populater); private: - std::unordered_map parsers; + std::unordered_map populaters; }; class AnfNodePopulaterRegistrar { public: - AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *parser) { - AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, parser); + AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *populater) { + AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, populater); } }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.cc similarity index 60% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.cc index 8c70bb46ae..0aa53df227 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.cc @@ -13,26 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" +#include "src/common/anf_importer/anf_populater/anf_pool_populater.h" #include #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { - auto p = GetCNodePrimitive(cnodePtr); +int AnfPoolPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - if (p->instance_name() == "MaxPool") { + if (prim->instance_name() == "MaxPool") { attr->poolingMode = schema::PoolMode_MAX_POOLING; - } else if (p->instance_name() == "MeanPool") { + } else if (prim->instance_name() == "MeanPool") { attr->poolingMode = schema::PoolMode_MEAN_POOLING; } - auto format = GetValue(p->GetAttr("data_format")); + auto format = GetValue(prim->GetAttr("data_format")); if (format == "NCHW") { attr->format = schema::Format_NCHW; } else if (format == "NHWC") { @@ -41,7 +41,7 @@ int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schem attr->format = schema::Format_NUM_OF_FORMAT; } - auto pad_mode = GetValue(p->GetAttr("padding")); + auto pad_mode = GetValue(prim->GetAttr("padding")); if (pad_mode == "VALID") { attr->padMode = schema::PadMode_VALID; } else if (pad_mode == "SAME") { @@ -50,19 +50,19 @@ int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schem attr->padMode = schema::PadMode_NOTSET; } - auto kernel_size = GetValue>(p->GetAttr("ksize")); + auto kernel_size = GetValue>(prim->GetAttr("ksize")); attr->windowH = kernel_size[2]; attr->windowW = kernel_size[3]; - auto stride = GetValue>(p->GetAttr("strides")); + auto stride = GetValue>(prim->GetAttr("strides")); attr->strideH = stride[2]; attr->strideW = stride[3]; - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Pooling; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Pooling; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfMaxPoolParser("MaxPool", new AnfPoolPopulater()); +AnfNodePopulaterRegistrar anfMaxPoolPopulater("MaxPool", new AnfPoolPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.h index a677e7baca..0589172505 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_POOL_PARSER_H #define MINDSPORE_ANF_POOL_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfPoolPopulater : public AnfNodePopulater { public: AnfPoolPopulater() = default; ~AnfPoolPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.cc similarity index 57% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.cc index 964f00c2a5..1f4c7e0716 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.cc @@ -13,23 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_quant_populater.h" +#include "src/common/anf_importer/anf_populater/anf_quant_populater.h" #include #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfQuantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { +int AnfQuantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfQuantParser("Quant", new AnfQuantPopulater()); +AnfNodePopulaterRegistrar anfQuantPopulater("Quant", new AnfQuantPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.h index 87a593b459..a9aed77da6 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_QUANT_PARSER_H #define MINDSPORE_ANF_QUANT_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfQuantPopulater : public AnfNodePopulater { public: AnfQuantPopulater() = default; ~AnfQuantPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.cc similarity index 65% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.cc index 8ec0a93cfe..00bf3d7105 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_reducemean_populater.h" +#include "src/common/anf_importer/anf_populater/anf_reducemean_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" @@ -25,15 +25,15 @@ namespace { constexpr int kReduceInputNum = 3; constexpr int kReduceInputIndex = 2; } -int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { - auto p = GetCNodePrimitive(cnodePtr); +int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); attr->mode = schema::ReduceMode_ReduceMean; - attr->keepDims = GetValue(p->GetAttr("keep_dims")); - if (cnodePtr->inputs().size() == kReduceInputNum) { - auto inputNode = cnodePtr->input(kReduceInputIndex); + attr->keepDims = GetValue(prim->GetAttr("keep_dims")); + if (inputs.size() == kReduceInputNum) { + auto inputNode = inputs[kReduceInputIndex]; MS_ASSERT(inputNode != nullptr); if (inputNode->isa()) { auto valueNode = inputNode->cast(); @@ -52,11 +52,11 @@ int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CN } } - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Reduce; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Reduce; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfReduceMeanParser("ReduceMean", new AnfReduceMeanPopulater()); +AnfNodePopulaterRegistrar anfReduceMeanPopulater("ReduceMean", new AnfReduceMeanPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.h index 16ac3b0c7e..f82a3997f0 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfReduceMeanPopulater : public AnfNodePopulater { public: AnfReduceMeanPopulater() = default; ~AnfReduceMeanPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.cc similarity index 65% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.cc index 6669d9f11c..9baea4130b 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.cc @@ -13,19 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h" +#include "src/common/anf_importer/anf_populater/anf_reshape_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { +int AnfReshapePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); - auto inputNode = cnodePtr->input(kAnfPopulaterTwo); + MS_ASSERT(inputs.size() == kAnfPopulaterThree); + auto inputNode = inputs[kAnfPopulaterTwo]; if (inputNode->isa()) { auto valueNode = inputNode->cast(); MS_ASSERT(valueNode != nullptr); @@ -42,12 +43,12 @@ int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, sc } } - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Reshape; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Reshape; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfReshapeParser("Reshape", new AnfReshapePopulater()); +AnfNodePopulaterRegistrar anfReshapePopulater("Reshape", new AnfReshapePopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.h index 776aab0f94..b46d931cf9 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.h @@ -16,14 +16,15 @@ #ifndef MINDSPORE_ANF_RESHAPE_PARSER_H #define MINDSPORE_ANF_RESHAPE_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfReshapePopulater : public AnfNodePopulater { public: AnfReshapePopulater() = default; ~AnfReshapePopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.cc similarity index 56% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.cc index e220e45b41..d0bb01c4c9 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.cc @@ -13,22 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h" +#include "src/common/anf_importer/anf_populater/anf_tensoradd_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfTensorAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { +int AnfTensorAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Add; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Add; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfTensorAddParser("TensorAdd", new AnfTensorAddPopulater()); +AnfNodePopulaterRegistrar anfTensorAddPopulater("TensorAdd", new AnfTensorAddPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.h index d8ff59bba7..b7ecf326fb 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfTensorAddPopulater : public AnfNodePopulater { public: AnfTensorAddPopulater() = default; ~AnfTensorAddPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.cc similarity index 65% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.cc index 76eafbec64..e2c1548ff6 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.cc @@ -13,21 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_transpose_populater.h" +#include "src/common/anf_importer/anf_populater/anf_transpose_populater.h" #include #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { +int AnfTransposePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - - MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); - auto inputNode = cnodePtr->input(kAnfPopulaterTwo); + MS_ASSERT(inputs.size() == kAnfPopulaterThree); + auto inputNode = inputs[kAnfPopulaterTwo]; if (inputNode->isa()) { auto valNode = inputNode->cast(); MS_ASSERT(valNode != nullptr); @@ -44,11 +44,11 @@ int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr, } } - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Transpose; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Transpose; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfTransposeParser("Transpose", new AnfTransposePopulater()); +AnfNodePopulaterRegistrar anfTransposePopulater("Transpose", new AnfTransposePopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.h index eecdbb7593..583912d2b1 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H #define MINDSPORE_ANF_TRANSPOSE_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfTransposePopulater : public AnfNodePopulater { public: AnfTransposePopulater() = default; ~AnfTransposePopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.cc similarity index 55% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.cc index 9f6092f4ae..ec5e6b7433 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.cc @@ -13,22 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h" +#include "src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h" #include #include -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { -int mindspore::lite::AnfTupleGetItemPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, - std::vector *outputs) { +int AnfTupleGetItemPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { + auto primitive = std::make_unique(); auto attr = std::make_unique(); - node->nodeType = schema::NodeType_CNode; - node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_TupleGetItem; - node->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_TupleGetItem; + primitive->value.value = attr.release(); + MS_ASSERT(primitiveTValuePtr != nullptr); + primitiveTValuePtr->SetPrimitiveT(primitive.release()); return 0; } -AnfNodePopulaterRegistrar anfTupleGetItemParser("tuple_getitem", new AnfTupleGetItemPopulater()); +AnfNodePopulaterRegistrar anfTupleGetItemPopulater("tuple_getitem", new AnfTupleGetItemPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h rename to mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h index 3acf2638c3..b6b256a39a 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h @@ -15,14 +15,15 @@ */ #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H #define MINDSPORE_ANF_BATCHNORM_PARSER_H -#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfTupleGetItemPopulater : public AnfNodePopulater { public: AnfTupleGetItemPopulater() = default; ~AnfTupleGetItemPopulater() override = default; - int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; + int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc index 28eae27b3a..0479c40486 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc @@ -28,6 +28,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "frontend/operator/ops.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "include/errorcode.h" @@ -38,6 +39,7 @@ #include "tools/converter/parser/onnx/onnx.pb.h" #include "utils/log_adapter.h" #include "securec/include/securec.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" using string = std::string; using int32 = int32_t; @@ -997,10 +999,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out return nullptr; } } - std::vector inputs; inputs.clear(); - inputs.push_back(NewValueNode(prim)); for (int i = 0; i < node_proto.input_size(); ++i) { const std::string &input_name = node_proto.input(i); if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { @@ -1009,6 +1009,18 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out } inputs.push_back(anfnode_build_map_[input_name]); } + std::string opType = prim->name(); + auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); + if (node_parser == nullptr) { + MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; + return nullptr; + } + auto primitiveT = std::make_unique(); + // auto * primitiveTValue = new PrimitiveTValue(primitiveT.release()); + std::shared_ptr primitiveTValuePtr = std::make_shared(primitiveT.release()); + node_parser->Populate(prim, primitiveTValuePtr.get(), inputs); + MS_ASSERT(primitiveTValuePtr != nullptr); + inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr)); CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(cnode_ptr); if (node_type == "LayerNorm") { diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 548dd99f9f..6c139fc9bc 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -215,9 +215,9 @@ if(BUILD_CONVERTER) ${TEST_CASE_TFLITE_PARSERS_SRC} ${TOP_DIR}/mindspore/core/utils/flags.cc ${LITE_DIR}/tools/converter/optimizer.cc - ${LITE_DIR}/src/common/anf_importer/anf_importer.cc - ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc - ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc + # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc + # ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc + # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc ${LITE_DIR}/tools/converter/anf_transform.cc ${LITE_DIR}/tools/converter/graphdef_transform.cc ${LITE_DIR}/tools/converter/converter_flags.cc @@ -345,7 +345,7 @@ if (BUILD_MINDDATA) endif() if (BUILD_CONVERTER) target_link_libraries(lite-test - anf_exporter_mid + anf_importer_mid tflite_parser_mid caffe_parser_mid node_mid diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index f8cc8bcfa5..9119138545 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -70,9 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/anf_importer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_meta_graphT.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_protobuf.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc ../optimizer/common/node_pass_extends.cc @@ -99,7 +97,7 @@ add_executable(converter_lite target_link_libraries(converter_lite PRIVATE tflite_parser_mid caffe_parser_mid - anf_exporter_mid + anf_importer_mid node_mid graph_pass_mid fusion_mid diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt index 24a17962c1..1777de07ed 100644 --- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -10,6 +10,7 @@ add_library(quantizer_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/common/anf_exporter/anf_exporter.cc ) if(ENABLE_ASAN)