| @@ -106,7 +106,7 @@ if (BUILD_CONVERTER) | |||
| include_directories(${TOP_DIR}/third_party/protobuf/build/include) | |||
| link_directories(${TOP_DIR}/third_party/protobuf/build/lib) | |||
| add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) | |||
| add_subdirectory(src/common/anf_exporter) | |||
| add_subdirectory(src/common/anf_importer) | |||
| endif() | |||
| if (BUILD_DEVICE) | |||
| @@ -25,7 +25,7 @@ | |||
| #include "abstract/abstract_value.h" | |||
| #include "base/core_ops.h" | |||
| #include "mindspore/core/ir/primitive.h" | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| // #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/ir/primitive_t_value.h" | |||
| #include "src/ir/tensor.h" | |||
| #include "src/param_value_lite.h" | |||
| @@ -148,27 +148,27 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||
| node->name = cnode->fullname_with_scope(); | |||
| node->nodeType = schema::NodeType_CNode; | |||
| // populate primitive | |||
| if (primitive != nullptr) { | |||
| primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(primitive != nullptr); | |||
| std::string opType = primitive->name(); | |||
| auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||
| if (nodeParser == nullptr) { | |||
| MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | |||
| return nullptr; | |||
| } | |||
| std::vector<schema::TensorT *> outputs; | |||
| if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) { | |||
| auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract()); | |||
| outputs.resize(abstract_cnode->size()); | |||
| } | |||
| nodeParser->Parse(cnode, node.get(), &outputs); | |||
| SetOpInputNode(cnode, metaGraphT.get(), node.get()); | |||
| SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); | |||
| metaGraphT->nodes.emplace_back(std::move(node)); | |||
| continue; | |||
| } | |||
| // if (primitive != nullptr) { | |||
| // primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| // MS_ASSERT(primitive != nullptr); | |||
| // std::string opType = primitive->name(); | |||
| // auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||
| // if (nodeParser == nullptr) { | |||
| // MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | |||
| // return nullptr; | |||
| // } | |||
| // std::vector<schema::TensorT *> outputs; | |||
| // if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) { | |||
| // auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract()); | |||
| // outputs.resize(abstract_cnode->size()); | |||
| // } | |||
| // | |||
| // nodeParser->Parse(cnode, node.get(), &outputs); | |||
| // SetOpInputNode(cnode, metaGraphT.get(), node.get()); | |||
| // SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); | |||
| // metaGraphT->nodes.emplace_back(std::move(node)); | |||
| // continue; | |||
| // } | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | |||
| if (primitiveT_value == nullptr) { | |||
| MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | |||
| @@ -1,7 +1,7 @@ | |||
| file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| *.cc | |||
| ) | |||
| add_library(anf_exporter_mid OBJECT | |||
| list(REMOVE_ITEM ANF_SRC_LIST import_from_meta_graph.cc) | |||
| add_library(anf_importer_mid OBJECT | |||
| ${ANF_SRC_LIST} | |||
| ) | |||
| @@ -13,33 +13,33 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_activation_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| auto p = GetCNodePrimitive(cnodePtr); | |||
| int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::ActivationT>(); | |||
| if (p->name() == "ReLU") { | |||
| if (prim->name() == "ReLU") { | |||
| attr->type = schema::ActivationType_RELU; | |||
| } else if (p->name() == "Sigmoid") { | |||
| } else if (prim->name() == "Sigmoid") { | |||
| attr->type = schema::ActivationType_SIGMOID; | |||
| } else if (p->name() == "ReLU6") { | |||
| } else if (prim->name() == "ReLU6") { | |||
| attr->type = schema::ActivationType_RELU6; | |||
| } | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Activation; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Activation; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfReLUParser("ReLU", new AnfActivationPopulater()); | |||
| AnfNodePopulaterRegistrar anfReLU6Parser("ReLU6", new AnfActivationPopulater()); | |||
| AnfNodePopulaterRegistrar anfSigmoidParser("Sigmoid", new AnfActivationPopulater()); | |||
| AnfNodePopulaterRegistrar anfReLUPopulater("ReLU", new AnfActivationPopulater()); | |||
| AnfNodePopulaterRegistrar anfReLU6Populater("ReLU6", new AnfActivationPopulater()); | |||
| AnfNodePopulaterRegistrar anfSigmoidPopulater("Sigmoid", new AnfActivationPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -16,14 +16,15 @@ | |||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | |||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfActivationPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfActivationPopulater() = default; | |||
| ~AnfActivationPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,25 +13,24 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_batchnorm_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfBatchnormParser::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| auto p = GetCNodePrimitive(cnodePtr); | |||
| int AnfBatchnormPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::FusedBatchNormT>(); | |||
| attr->epsilon = GetValue<float>(p->GetAttr("epsilon")); | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; | |||
| node->primitive->value.value = attr.release(); | |||
| attr->epsilon = GetValue<float>(prim->GetAttr("epsilon")); | |||
| primitive->value.type = schema::PrimitiveType_FusedBatchNorm; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfBatchnormParser("BatchNorm", new AnfBatchnormParser()); | |||
| AnfNodePopulaterRegistrar anfBatchnormPopulater("BatchNorm", new AnfBatchnormPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | |||
| #define MINDSPORE_ANF_BATCHNORM_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfBatchnormParser : public AnfNodePopulater { | |||
| class AnfBatchnormPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfBatchnormParser() = default; | |||
| ~AnfBatchnormParser() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| AnfBatchnormPopulater() = default; | |||
| ~AnfBatchnormPopulater() override = default; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,25 +13,25 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_biasadd_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfBiasAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| int AnfBiasAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::BiasAddT>(); | |||
| attr->axis = {0}; | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_BiasAdd; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_BiasAdd; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfBiasAddParser("BiasAdd", new AnfBiasAddPopulater()); | |||
| AnfNodePopulaterRegistrar anfBiasAddPopulater("BiasAdd", new AnfBiasAddPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_BIASADD_PARSER_H | |||
| #define MINDSPORE_ANF_BIASADD_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfBiasAddPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfBiasAddPopulater() = default; | |||
| ~AnfBiasAddPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -16,30 +16,27 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_concat_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_concat_populater.h" | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfConcatPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| auto p = GetCNodePrimitive(cnodePtr); | |||
| int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::ConcatT>(); | |||
| auto prim_axis = GetValue<int>(p->GetAttr("axis")); | |||
| auto prim_axis = GetValue<int>(prim->GetAttr("axis")); | |||
| attr->axis = prim_axis; | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Concat; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Concat; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfConcatParser("Concat", new AnfConcatPopulater()); | |||
| AnfNodePopulaterRegistrar anfConcatPopulater("Concat", new AnfConcatPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -18,14 +18,15 @@ | |||
| #ifndef MINDSPORE_ANF_CONCAT_PARSER_H | |||
| #define MINDSPORE_ANF_CONCAT_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfConcatPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfConcatPopulater() = default; | |||
| ~AnfConcatPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtrr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -16,23 +16,22 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_conv_populater.h" | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| auto p = GetCNodePrimitive(cnodePtr); | |||
| int group = GetValue<int>(p->GetAttr("group")); | |||
| int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| int group = GetValue<int>(prim->GetAttr("group")); | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (group > 1) { | |||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||
| auto format = GetValue<std::string>(p->GetAttr("data_format")); | |||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| @@ -40,25 +39,25 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list")); | |||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list")); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); | |||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); | |||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); | |||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); | |||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "same") { | |||
| @@ -67,14 +66,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| primitive->value.value = attr.release(); | |||
| } else { | |||
| auto attr = std::make_unique<schema::Conv2DT>(); | |||
| attr->group = group; | |||
| auto format = GetValue<std::string>(p->GetAttr("data_format")); | |||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| @@ -82,27 +79,27 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list")); | |||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list")); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); | |||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); | |||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); | |||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| attr->channelOut = GetValue<int>(p->GetAttr("out_channel")); | |||
| attr->channelOut = GetValue<int>(prim->GetAttr("out_channel")); | |||
| auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); | |||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "same") { | |||
| @@ -110,12 +107,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||
| } else { | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfConvParser("Conv2D", new AnfConvPopulater()); | |||
| AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -18,14 +18,15 @@ | |||
| #ifndef MINDSPORE_ANF_CONV_PARSER_H | |||
| #define MINDSPORE_ANF_CONV_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfConvPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfConvPopulater() = default; | |||
| ~AnfConvPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,21 +13,21 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| auto p = GetCNodePrimitive(cnodePtr); | |||
| int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||
| auto format = GetValue<std::string>(p->GetAttr("data_format")); | |||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| @@ -35,25 +35,25 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pads")); | |||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pads")); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); | |||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); | |||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); | |||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); | |||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "same") { | |||
| @@ -62,11 +62,11 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| auto channel_multiplier = GetValue<int>(p->GetAttr("channel_multiplier")); | |||
| auto channel_multiplier = GetValue<int>(prim->GetAttr("channel_multiplier")); | |||
| attr->channelMultiplier = channel_multiplier; | |||
| MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); | |||
| auto inputNode = cnodePtr->input(kAnfPopulaterTwo); | |||
| MS_ASSERT(inputs.size() == kAnfPopulaterThree); | |||
| auto inputNode = inputs[kAnfPopulaterTwo]; | |||
| MS_ASSERT(inputNode != nullptr); | |||
| if (inputNode->isa<Parameter>()) { | |||
| auto paramNode = inputNode->cast<ParameterPtr>(); | |||
| @@ -82,12 +82,12 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP | |||
| } | |||
| } | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfdepthwise2dParser("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); | |||
| AnfNodePopulaterRegistrar anfdepthwise2dnativeParser("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); | |||
| AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); | |||
| AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | |||
| #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfDepwiseconv2DPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfDepwiseconv2DPopulater() = default; | |||
| ~AnfDepwiseconv2DPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,23 +13,24 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_dequant_populater.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfDequantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| int AnfDequantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::OnnxInt8DequantizeT>(); | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfDequantParser("Dequant", new AnfDequantPopulater()); | |||
| AnfNodePopulaterRegistrar anfDequantPopulater("Dequant", new AnfDequantPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_DEQUANT_PARSER_H | |||
| #define MINDSPORE_ANF_DEQUANT_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfDequantPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfDequantPopulater() = default; | |||
| ~AnfDequantPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,23 +13,24 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_flatten_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfFlattenPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| int AnfFlattenPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::FlattenT>(); | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Flatten; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Flatten; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfFlattenParser("Flatten", new AnfFlattenPopulater()); | |||
| AnfNodePopulaterRegistrar anfFlattenPopulater("Flatten", new AnfFlattenPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_FLATTEN_PARSER_H | |||
| #define MINDSPORE_ANF_FLATTEN_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfFlattenPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfFlattenPopulater() = default; | |||
| ~AnfFlattenPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,26 +13,26 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_matmul_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| auto p = GetCNodePrimitive(cnodePtr); | |||
| int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::MatMulT>(); | |||
| attr->transposeA = GetValue<bool>(p->GetAttr("transpose_a")); | |||
| attr->transposeB = GetValue<bool>(p->GetAttr("transpose_b")); | |||
| attr->transposeA = GetValue<bool>(prim->GetAttr("transpose_a")); | |||
| attr->transposeB = GetValue<bool>(prim->GetAttr("transpose_b")); | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_MatMul; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_MatMul; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfMatmulParser("MatMul", new AnfMatmulPopulater()); | |||
| AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_MATMUL_PARSER_H | |||
| #define MINDSPORE_ANF_MATMUL_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfMatmulPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfMatmulPopulater() = default; | |||
| ~AnfMatmulPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,23 +13,23 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_mul_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_mul_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| int AnfMulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::MulT>(); | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Mul; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Mul; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater()); | |||
| AnfNodePopulaterRegistrar anfMulPopulater("Mul", new AnfMulPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | |||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfMulPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfMulPopulater() = default; | |||
| ~AnfMulPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -14,6 +14,6 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| namespace mindspore::lite {} // namespace mindspore::lite | |||
| @@ -19,6 +19,7 @@ | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "src/ir/primitive_t_value.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore::lite { | |||
| constexpr int kAnfPopulaterOne = 1; | |||
| @@ -28,7 +29,9 @@ class AnfNodePopulater { | |||
| public: | |||
| AnfNodePopulater() = default; | |||
| virtual ~AnfNodePopulater() = default; | |||
| virtual int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) = 0; | |||
| virtual int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) = 0; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -14,14 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include <string> | |||
| #include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" | |||
| #include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" | |||
| #include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" | |||
| #include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" | |||
| #include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" | |||
| #include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { | |||
| @@ -29,13 +23,13 @@ AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { | |||
| return &instance; | |||
| } | |||
| AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) { | |||
| if (parsers.find(name) == parsers.end()) { | |||
| if (populaters.find(name) == populaters.end()) { | |||
| return nullptr; | |||
| } | |||
| return parsers[name]; | |||
| return populaters[name]; | |||
| } | |||
| void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *parser) { | |||
| parsers[name] = parser; | |||
| void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *populater) { | |||
| populaters[name] = populater; | |||
| } | |||
| } // namespace lite | |||
| @@ -16,7 +16,7 @@ | |||
| #ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | |||
| #define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <unordered_map> | |||
| #include <string> | |||
| namespace mindspore::lite { | |||
| @@ -26,16 +26,16 @@ class AnfNodePopulaterRegistry { | |||
| virtual ~AnfNodePopulaterRegistry() = default; | |||
| static AnfNodePopulaterRegistry *GetInstance(); | |||
| AnfNodePopulater *GetNodePopulater(const std::string &name); | |||
| void SetNodePopulater(const std::string &name, AnfNodePopulater *parser); | |||
| void SetNodePopulater(const std::string &name, AnfNodePopulater *populater); | |||
| private: | |||
| std::unordered_map<std::string, AnfNodePopulater *> parsers; | |||
| std::unordered_map<std::string, AnfNodePopulater *> populaters; | |||
| }; | |||
| class AnfNodePopulaterRegistrar { | |||
| public: | |||
| AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *parser) { | |||
| AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, parser); | |||
| AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *populater) { | |||
| AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, populater); | |||
| } | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,26 +13,26 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_pool_populater.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| auto p = GetCNodePrimitive(cnodePtr); | |||
| int AnfPoolPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::PoolingT>(); | |||
| if (p->instance_name() == "MaxPool") { | |||
| if (prim->instance_name() == "MaxPool") { | |||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | |||
| } else if (p->instance_name() == "MeanPool") { | |||
| } else if (prim->instance_name() == "MeanPool") { | |||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | |||
| } | |||
| auto format = GetValue<std::string>(p->GetAttr("data_format")); | |||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| @@ -41,7 +41,7 @@ int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_mode = GetValue<std::string>(p->GetAttr("padding")); | |||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("padding")); | |||
| if (pad_mode == "VALID") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "SAME") { | |||
| @@ -50,19 +50,19 @@ int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schem | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("ksize")); | |||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("ksize")); | |||
| attr->windowH = kernel_size[2]; | |||
| attr->windowW = kernel_size[3]; | |||
| auto stride = GetValue<std::vector<int>>(p->GetAttr("strides")); | |||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("strides")); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Pooling; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Pooling; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfMaxPoolParser("MaxPool", new AnfPoolPopulater()); | |||
| AnfNodePopulaterRegistrar anfMaxPoolPopulater("MaxPool", new AnfPoolPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_POOL_PARSER_H | |||
| #define MINDSPORE_ANF_POOL_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfPoolPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfPoolPopulater() = default; | |||
| ~AnfPoolPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,23 +13,24 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_quant_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_quant_populater.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfQuantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| int AnfQuantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::OnnxInt8QuantizeT>(); | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfQuantParser("Quant", new AnfQuantPopulater()); | |||
| AnfNodePopulaterRegistrar anfQuantPopulater("Quant", new AnfQuantPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_QUANT_PARSER_H | |||
| #define MINDSPORE_ANF_QUANT_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfQuantPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfQuantPopulater() = default; | |||
| ~AnfQuantPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,10 +13,10 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_reducemean_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_reducemean_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| @@ -25,15 +25,15 @@ namespace { | |||
| constexpr int kReduceInputNum = 3; | |||
| constexpr int kReduceInputIndex = 2; | |||
| } | |||
| int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| auto p = GetCNodePrimitive(cnodePtr); | |||
| int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::ReduceT>(); | |||
| attr->mode = schema::ReduceMode_ReduceMean; | |||
| attr->keepDims = GetValue<bool>(p->GetAttr("keep_dims")); | |||
| if (cnodePtr->inputs().size() == kReduceInputNum) { | |||
| auto inputNode = cnodePtr->input(kReduceInputIndex); | |||
| attr->keepDims = GetValue<bool>(prim->GetAttr("keep_dims")); | |||
| if (inputs.size() == kReduceInputNum) { | |||
| auto inputNode = inputs[kReduceInputIndex]; | |||
| MS_ASSERT(inputNode != nullptr); | |||
| if (inputNode->isa<ValueNode>()) { | |||
| auto valueNode = inputNode->cast<ValueNodePtr>(); | |||
| @@ -52,11 +52,11 @@ int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CN | |||
| } | |||
| } | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Reduce; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Reduce; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfReduceMeanParser("ReduceMean", new AnfReduceMeanPopulater()); | |||
| AnfNodePopulaterRegistrar anfReduceMeanPopulater("ReduceMean", new AnfReduceMeanPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | |||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfReduceMeanPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfReduceMeanPopulater() = default; | |||
| ~AnfReduceMeanPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,19 +13,20 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_reshape_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| int AnfReshapePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::ReshapeT>(); | |||
| MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); | |||
| auto inputNode = cnodePtr->input(kAnfPopulaterTwo); | |||
| MS_ASSERT(inputs.size() == kAnfPopulaterThree); | |||
| auto inputNode = inputs[kAnfPopulaterTwo]; | |||
| if (inputNode->isa<ValueNode>()) { | |||
| auto valueNode = inputNode->cast<ValueNodePtr>(); | |||
| MS_ASSERT(valueNode != nullptr); | |||
| @@ -42,12 +43,12 @@ int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, sc | |||
| } | |||
| } | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Reshape; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Reshape; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfReshapeParser("Reshape", new AnfReshapePopulater()); | |||
| AnfNodePopulaterRegistrar anfReshapePopulater("Reshape", new AnfReshapePopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -16,14 +16,15 @@ | |||
| #ifndef MINDSPORE_ANF_RESHAPE_PARSER_H | |||
| #define MINDSPORE_ANF_RESHAPE_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfReshapePopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfReshapePopulater() = default; | |||
| ~AnfReshapePopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,22 +13,23 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_tensoradd_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfTensorAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| int AnfTensorAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::AddT>(); | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Add; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Add; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfTensorAddParser("TensorAdd", new AnfTensorAddPopulater()); | |||
| AnfNodePopulaterRegistrar anfTensorAddPopulater("TensorAdd", new AnfTensorAddPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | |||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfTensorAddPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfTensorAddPopulater() = default; | |||
| ~AnfTensorAddPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,21 +13,21 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_transpose_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_transpose_populater.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| int AnfTransposePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::TransposeT>(); | |||
| MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); | |||
| auto inputNode = cnodePtr->input(kAnfPopulaterTwo); | |||
| MS_ASSERT(inputs.size() == kAnfPopulaterThree); | |||
| auto inputNode = inputs[kAnfPopulaterTwo]; | |||
| if (inputNode->isa<ValueNode>()) { | |||
| auto valNode = inputNode->cast<ValueNodePtr>(); | |||
| MS_ASSERT(valNode != nullptr); | |||
| @@ -44,11 +44,11 @@ int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr, | |||
| } | |||
| } | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Transpose; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfTransposeParser("Transpose", new AnfTransposePopulater()); | |||
| AnfNodePopulaterRegistrar anfTransposePopulater("Transpose", new AnfTransposePopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H | |||
| #define MINDSPORE_ANF_TRANSPOSE_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfTransposePopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfTransposePopulater() = default; | |||
| ~AnfTransposePopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -13,22 +13,23 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore::lite { | |||
| int mindspore::lite::AnfTupleGetItemPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, | |||
| std::vector<schema::TensorT *> *outputs) { | |||
| int AnfTupleGetItemPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| auto attr = std::make_unique<schema::TupleGetItemT>(); | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| node->primitive->value.type = schema::PrimitiveType_TupleGetItem; | |||
| node->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_TupleGetItem; | |||
| primitive->value.value = attr.release(); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| return 0; | |||
| } | |||
| AnfNodePopulaterRegistrar anfTupleGetItemParser("tuple_getitem", new AnfTupleGetItemPopulater()); | |||
| AnfNodePopulaterRegistrar anfTupleGetItemPopulater("tuple_getitem", new AnfTupleGetItemPopulater()); | |||
| } // namespace mindspore::lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | |||
| #define MINDSPORE_ANF_BATCHNORM_PARSER_H | |||
| #include "src/common/anf_exporter/anf_populater/anf_node_populater.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| class AnfTupleGetItemPopulater : public AnfNodePopulater { | |||
| public: | |||
| AnfTupleGetItemPopulater() = default; | |||
| ~AnfTupleGetItemPopulater() override = default; | |||
| int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; | |||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||
| const std::vector<AnfNodePtr> &inputs) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -28,6 +28,7 @@ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||
| #include "include/errorcode.h" | |||
| @@ -38,6 +39,7 @@ | |||
| #include "tools/converter/parser/onnx/onnx.pb.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "securec/include/securec.h" | |||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||
| using string = std::string; | |||
| using int32 = int32_t; | |||
| @@ -997,10 +999,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out | |||
| return nullptr; | |||
| } | |||
| } | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.clear(); | |||
| inputs.push_back(NewValueNode(prim)); | |||
| for (int i = 0; i < node_proto.input_size(); ++i) { | |||
| const std::string &input_name = node_proto.input(i); | |||
| if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | |||
| @@ -1009,6 +1009,18 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out | |||
| } | |||
| inputs.push_back(anfnode_build_map_[input_name]); | |||
| } | |||
| std::string opType = prim->name(); | |||
| auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||
| if (node_parser == nullptr) { | |||
| MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | |||
| return nullptr; | |||
| } | |||
| auto primitiveT = std::make_unique<schema::PrimitiveT>(); | |||
| // auto * primitiveTValue = new PrimitiveTValue(primitiveT.release()); | |||
| std::shared_ptr<PrimitiveTValue> primitiveTValuePtr = std::make_shared<PrimitiveTValue>(primitiveT.release()); | |||
| node_parser->Populate(prim, primitiveTValuePtr.get(), inputs); | |||
| MS_ASSERT(primitiveTValuePtr != nullptr); | |||
| inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr)); | |||
| CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||
| if (node_type == "LayerNorm") { | |||
| @@ -215,9 +215,9 @@ if(BUILD_CONVERTER) | |||
| ${TEST_CASE_TFLITE_PARSERS_SRC} | |||
| ${TOP_DIR}/mindspore/core/utils/flags.cc | |||
| ${LITE_DIR}/tools/converter/optimizer.cc | |||
| ${LITE_DIR}/src/common/anf_importer/anf_importer.cc | |||
| ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc | |||
| ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc | |||
| # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc | |||
| # ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc | |||
| # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc | |||
| ${LITE_DIR}/tools/converter/anf_transform.cc | |||
| ${LITE_DIR}/tools/converter/graphdef_transform.cc | |||
| ${LITE_DIR}/tools/converter/converter_flags.cc | |||
| @@ -345,7 +345,7 @@ if (BUILD_MINDDATA) | |||
| endif() | |||
| if (BUILD_CONVERTER) | |||
| target_link_libraries(lite-test | |||
| anf_exporter_mid | |||
| anf_importer_mid | |||
| tflite_parser_mid | |||
| caffe_parser_mid | |||
| node_mid | |||
| @@ -70,9 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/anf_importer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_meta_graphT.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_protobuf.cc | |||
| # ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc | |||
| ../optimizer/common/node_pass_extends.cc | |||
| @@ -99,7 +97,7 @@ add_executable(converter_lite | |||
| target_link_libraries(converter_lite PRIVATE | |||
| tflite_parser_mid | |||
| caffe_parser_mid | |||
| anf_exporter_mid | |||
| anf_importer_mid | |||
| node_mid | |||
| graph_pass_mid | |||
| fusion_mid | |||
| @@ -10,6 +10,7 @@ add_library(quantizer_mid OBJECT | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/common/anf_exporter/anf_exporter.cc | |||
| ) | |||
| if(ENABLE_ASAN) | |||