Browse Source

adjust anf

tags/v0.7.0-beta
xuanyue 5 years ago
parent
commit
3a07af9633
45 changed files with 333 additions and 309 deletions
  1. +1
    -1
      mindspore/lite/CMakeLists.txt
  2. +22
    -22
      mindspore/lite/src/common/anf_exporter/anf_exporter.cc
  3. +2
    -2
      mindspore/lite/src/common/anf_importer/CMakeLists.txt
  4. +15
    -15
      mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.cc
  5. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.h
  6. +11
    -12
      mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.cc
  7. +6
    -5
      mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.h
  8. +10
    -10
      mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.cc
  9. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.h
  10. +11
    -14
      mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.cc
  11. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.h
  12. +26
    -29
      mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc
  13. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h
  14. +20
    -20
      mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc
  15. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h
  16. +10
    -9
      mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.cc
  17. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.h
  18. +10
    -9
      mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.cc
  19. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.h
  20. +12
    -12
      mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc
  21. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h
  22. +10
    -10
      mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.cc
  23. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.h
  24. +1
    -1
      mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.cc
  25. +4
    -1
      mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.h
  26. +5
    -11
      mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.cc
  27. +5
    -5
      mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.h
  28. +16
    -16
      mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.cc
  29. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.h
  30. +10
    -9
      mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.cc
  31. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.h
  32. +13
    -13
      mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.cc
  33. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.h
  34. +12
    -11
      mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.cc
  35. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.h
  36. +10
    -9
      mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.cc
  37. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.h
  38. +12
    -12
      mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.cc
  39. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.h
  40. +10
    -9
      mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.cc
  41. +3
    -2
      mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h
  42. +14
    -2
      mindspore/lite/src/common/anf_importer/import_from_protobuf.cc
  43. +4
    -4
      mindspore/lite/test/CMakeLists.txt
  44. +2
    -4
      mindspore/lite/tools/converter/CMakeLists.txt
  45. +1
    -0
      mindspore/lite/tools/converter/quantizer/CMakeLists.txt

+ 1
- 1
mindspore/lite/CMakeLists.txt View File

@@ -106,7 +106,7 @@ if (BUILD_CONVERTER)
include_directories(${TOP_DIR}/third_party/protobuf/build/include)
link_directories(${TOP_DIR}/third_party/protobuf/build/lib)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter)
add_subdirectory(src/common/anf_exporter)
add_subdirectory(src/common/anf_importer)
endif()

if (BUILD_DEVICE)


+ 22
- 22
mindspore/lite/src/common/anf_exporter/anf_exporter.cc View File

@@ -25,7 +25,7 @@
#include "abstract/abstract_value.h"
#include "base/core_ops.h"
#include "mindspore/core/ir/primitive.h"
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
// #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/ir/primitive_t_value.h"
#include "src/ir/tensor.h"
#include "src/param_value_lite.h"
@@ -148,27 +148,27 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
node->name = cnode->fullname_with_scope();
node->nodeType = schema::NodeType_CNode;
// populate primitive
if (primitive != nullptr) {
primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(primitive != nullptr);
std::string opType = primitive->name();
auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
if (nodeParser == nullptr) {
MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
return nullptr;
}
std::vector<schema::TensorT *> outputs;
if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) {
auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
outputs.resize(abstract_cnode->size());
}
nodeParser->Parse(cnode, node.get(), &outputs);
SetOpInputNode(cnode, metaGraphT.get(), node.get());
SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get());
metaGraphT->nodes.emplace_back(std::move(node));
continue;
}
// if (primitive != nullptr) {
// primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
// MS_ASSERT(primitive != nullptr);
// std::string opType = primitive->name();
// auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
// if (nodeParser == nullptr) {
// MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
// return nullptr;
// }
// std::vector<schema::TensorT *> outputs;
// if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) {
// auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
// outputs.resize(abstract_cnode->size());
// }
//
// nodeParser->Parse(cnode, node.get(), &outputs);
// SetOpInputNode(cnode, metaGraphT.get(), node.get());
// SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get());
// metaGraphT->nodes.emplace_back(std::move(node));
// continue;
// }
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";


mindspore/lite/src/common/anf_exporter/CMakeLists.txt → mindspore/lite/src/common/anf_importer/CMakeLists.txt View File

@@ -1,7 +1,7 @@
file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
*.cc
)
add_library(anf_exporter_mid OBJECT
list(REMOVE_ITEM ANF_SRC_LIST import_from_meta_graph.cc)
add_library(anf_importer_mid OBJECT
${ANF_SRC_LIST}
)


mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.cc View File

@@ -13,33 +13,33 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h"
#include "src/common/anf_importer/anf_populater/anf_activation_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::ActivationT>();
if (p->name() == "ReLU") {
if (prim->name() == "ReLU") {
attr->type = schema::ActivationType_RELU;
} else if (p->name() == "Sigmoid") {
} else if (prim->name() == "Sigmoid") {
attr->type = schema::ActivationType_SIGMOID;
} else if (p->name() == "ReLU6") {
} else if (prim->name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
}

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Activation;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Activation;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfReLUParser("ReLU", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfReLU6Parser("ReLU6", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfSigmoidParser("Sigmoid", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfReLUPopulater("ReLU", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfReLU6Populater("ReLU6", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfSigmoidPopulater("Sigmoid", new AnfActivationPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.h View File

@@ -16,14 +16,15 @@

#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H
#define MINDSPORE_ANF_ACTIVATION_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfActivationPopulater : public AnfNodePopulater {
public:
AnfActivationPopulater() = default;
~AnfActivationPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.cc View File

@@ -13,25 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h"
#include "src/common/anf_importer/anf_populater/anf_batchnorm_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfBatchnormParser::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfBatchnormPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::FusedBatchNormT>();
attr->epsilon = GetValue<float>(p->GetAttr("epsilon"));

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
node->primitive->value.value = attr.release();
attr->epsilon = GetValue<float>(prim->GetAttr("epsilon"));
primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfBatchnormParser("BatchNorm", new AnfBatchnormParser());
AnfNodePopulaterRegistrar anfBatchnormPopulater("BatchNorm", new AnfBatchnormPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H
#define MINDSPORE_ANF_BATCHNORM_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfBatchnormParser : public AnfNodePopulater {
class AnfBatchnormPopulater : public AnfNodePopulater {
public:
AnfBatchnormParser() = default;
~AnfBatchnormParser() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
AnfBatchnormPopulater() = default;
~AnfBatchnormPopulater() override = default;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.cc View File

@@ -13,25 +13,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h"
#include "src/common/anf_importer/anf_populater/anf_biasadd_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfBiasAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfBiasAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::BiasAddT>();
attr->axis = {0};

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_BiasAdd;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_BiasAdd;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}

AnfNodePopulaterRegistrar anfBiasAddParser("BiasAdd", new AnfBiasAddPopulater());
AnfNodePopulaterRegistrar anfBiasAddPopulater("BiasAdd", new AnfBiasAddPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_BIASADD_PARSER_H
#define MINDSPORE_ANF_BIASADD_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfBiasAddPopulater : public AnfNodePopulater {
public:
AnfBiasAddPopulater() = default;
~AnfBiasAddPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.cc View File

@@ -16,30 +16,27 @@
* limitations under the License.
*/

#include "src/common/anf_exporter/anf_populater/anf_concat_populater.h"
#include "src/common/anf_importer/anf_populater/anf_concat_populater.h"
#include <string>
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfConcatPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::ConcatT>();

auto prim_axis = GetValue<int>(p->GetAttr("axis"));
auto prim_axis = GetValue<int>(prim->GetAttr("axis"));
attr->axis = prim_axis;

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Concat;
node->primitive->value.value = attr.release();

primitive->value.type = schema::PrimitiveType_Concat;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}

AnfNodePopulaterRegistrar anfConcatParser("Concat", new AnfConcatPopulater());
AnfNodePopulaterRegistrar anfConcatPopulater("Concat", new AnfConcatPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.h View File

@@ -18,14 +18,15 @@

#ifndef MINDSPORE_ANF_CONCAT_PARSER_H
#define MINDSPORE_ANF_CONCAT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfConcatPopulater : public AnfNodePopulater {
public:
AnfConcatPopulater() = default;
~AnfConcatPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtrr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc View File

@@ -16,23 +16,22 @@
* limitations under the License.
*/

#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h"
#include "src/common/anf_importer/anf_populater/anf_conv_populater.h"
#include <string>
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int group = GetValue<int>(p->GetAttr("group"));

int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
int group = GetValue<int>(prim->GetAttr("group"));
auto primitive = std::make_unique<schema::PrimitiveT>();
if (group > 1) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(p->GetAttr("data_format"));
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
@@ -40,25 +39,25 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list"));
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];

auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation"));
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];

auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size"));
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];

auto stride = GetValue<std::vector<int>>(p->GetAttr("stride"));
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];

auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode"));
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
@@ -67,14 +66,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
attr->padMode = schema::PadMode_NOTSET;
}

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
} else {
auto attr = std::make_unique<schema::Conv2DT>();
attr->group = group;
auto format = GetValue<std::string>(p->GetAttr("data_format"));
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
@@ -82,27 +79,27 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list"));
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];

auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation"));
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];

auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size"));
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];

auto stride = GetValue<std::vector<int>>(p->GetAttr("stride"));
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];

attr->channelOut = GetValue<int>(p->GetAttr("out_channel"));
attr->channelOut = GetValue<int>(prim->GetAttr("out_channel"));

auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode"));
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
@@ -110,12 +107,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
} else {
attr->padMode = schema::PadMode_NOTSET;
}
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Conv2D;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
}
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}

AnfNodePopulaterRegistrar anfConvParser("Conv2D", new AnfConvPopulater());
AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h View File

@@ -18,14 +18,15 @@

#ifndef MINDSPORE_ANF_CONV_PARSER_H
#define MINDSPORE_ANF_CONV_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfConvPopulater : public AnfNodePopulater {
public:
AnfConvPopulater() = default;
~AnfConvPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc View File

@@ -13,21 +13,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h"
#include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::DepthwiseConv2DT>();

auto format = GetValue<std::string>(p->GetAttr("data_format"));
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
@@ -35,25 +35,25 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pads"));
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pads"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];

auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation"));
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];

auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size"));
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];

auto stride = GetValue<std::vector<int>>(p->GetAttr("stride"));
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];

auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode"));
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
@@ -62,11 +62,11 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP
attr->padMode = schema::PadMode_NOTSET;
}

auto channel_multiplier = GetValue<int>(p->GetAttr("channel_multiplier"));
auto channel_multiplier = GetValue<int>(prim->GetAttr("channel_multiplier"));
attr->channelMultiplier = channel_multiplier;

MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree);
auto inputNode = cnodePtr->input(kAnfPopulaterTwo);
MS_ASSERT(inputs.size() == kAnfPopulaterThree);
auto inputNode = inputs[kAnfPopulaterTwo];
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<Parameter>()) {
auto paramNode = inputNode->cast<ParameterPtr>();
@@ -82,12 +82,12 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP
}
}

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfdepthwise2dParser("DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dnativeParser("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
public:
AnfDepwiseconv2DPopulater() = default;
~AnfDepwiseconv2DPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.cc View File

@@ -13,23 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h"
#include "src/common/anf_importer/anf_populater/anf_dequant_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfDequantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfDequantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::OnnxInt8DequantizeT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfDequantParser("Dequant", new AnfDequantPopulater());
AnfNodePopulaterRegistrar anfDequantPopulater("Dequant", new AnfDequantPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_DEQUANT_PARSER_H
#define MINDSPORE_ANF_DEQUANT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfDequantPopulater : public AnfNodePopulater {
public:
AnfDequantPopulater() = default;
~AnfDequantPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.cc View File

@@ -13,23 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h"
#include "src/common/anf_importer/anf_populater/anf_flatten_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfFlattenPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfFlattenPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::FlattenT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Flatten;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Flatten;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}

AnfNodePopulaterRegistrar anfFlattenParser("Flatten", new AnfFlattenPopulater());
AnfNodePopulaterRegistrar anfFlattenPopulater("Flatten", new AnfFlattenPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_FLATTEN_PARSER_H
#define MINDSPORE_ANF_FLATTEN_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfFlattenPopulater : public AnfNodePopulater {
public:
AnfFlattenPopulater() = default;
~AnfFlattenPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc View File

@@ -13,26 +13,26 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h"
#include "src/common/anf_importer/anf_populater/anf_matmul_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::MatMulT>();
attr->transposeA = GetValue<bool>(p->GetAttr("transpose_a"));
attr->transposeB = GetValue<bool>(p->GetAttr("transpose_b"));
attr->transposeA = GetValue<bool>(prim->GetAttr("transpose_a"));
attr->transposeB = GetValue<bool>(prim->GetAttr("transpose_b"));

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_MatMul;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_MatMul;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfMatmulParser("MatMul", new AnfMatmulPopulater());
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_MATMUL_PARSER_H
#define MINDSPORE_ANF_MATMUL_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfMatmulPopulater : public AnfNodePopulater {
public:
AnfMatmulPopulater() = default;
~AnfMatmulPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.cc View File

@@ -13,23 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_mul_populater.h"
#include "src/common/anf_importer/anf_populater/anf_mul_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfMulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::MulT>();

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Mul;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Mul;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater());
AnfNodePopulaterRegistrar anfMulPopulater("Mul", new AnfMulPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H
#define MINDSPORE_ANF_ACTIVATION_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfMulPopulater : public AnfNodePopulater {
public:
AnfMulPopulater() = default;
~AnfMulPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.cc View File

@@ -14,6 +14,6 @@
* limitations under the License.
*/

#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"

namespace mindspore::lite {} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.h View File

@@ -19,6 +19,7 @@

#include <vector>
#include "ir/anf.h"
#include "src/ir/primitive_t_value.h"
#include "schema/inner/model_generated.h"
namespace mindspore::lite {
constexpr int kAnfPopulaterOne = 1;
@@ -28,7 +29,9 @@ class AnfNodePopulater {
public:
AnfNodePopulater() = default;
virtual ~AnfNodePopulater() = default;
virtual int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) = 0;

virtual int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) = 0;
};

} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.cc View File

@@ -14,14 +14,8 @@
* limitations under the License.
*/

#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include <string>
#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h"
#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h"
#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h"
#include "src/common/anf_exporter/anf_populater/anf_pool_populater.h"
#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h"
#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h"
namespace mindspore {
namespace lite {
AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() {
@@ -29,13 +23,13 @@ AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() {
return &instance;
}
AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) {
if (parsers.find(name) == parsers.end()) {
if (populaters.find(name) == populaters.end()) {
return nullptr;
}
return parsers[name];
return populaters[name];
}
void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *parser) {
parsers[name] = parser;
void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *populater) {
populaters[name] = populater;
}

} // namespace lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.h View File

@@ -16,7 +16,7 @@

#ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H
#define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <unordered_map>
#include <string>
namespace mindspore::lite {
@@ -26,16 +26,16 @@ class AnfNodePopulaterRegistry {
virtual ~AnfNodePopulaterRegistry() = default;
static AnfNodePopulaterRegistry *GetInstance();
AnfNodePopulater *GetNodePopulater(const std::string &name);
void SetNodePopulater(const std::string &name, AnfNodePopulater *parser);
void SetNodePopulater(const std::string &name, AnfNodePopulater *populater);

private:
std::unordered_map<std::string, AnfNodePopulater *> parsers;
std::unordered_map<std::string, AnfNodePopulater *> populaters;
};

class AnfNodePopulaterRegistrar {
public:
AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *parser) {
AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, parser);
AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *populater) {
AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, populater);
}
};
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.cc View File

@@ -13,26 +13,26 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_pool_populater.h"
#include "src/common/anf_importer/anf_populater/anf_pool_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfPoolPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::PoolingT>();
if (p->instance_name() == "MaxPool") {
if (prim->instance_name() == "MaxPool") {
attr->poolingMode = schema::PoolMode_MAX_POOLING;
} else if (p->instance_name() == "MeanPool") {
} else if (prim->instance_name() == "MeanPool") {
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
}

auto format = GetValue<std::string>(p->GetAttr("data_format"));
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
@@ -41,7 +41,7 @@ int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
attr->format = schema::Format_NUM_OF_FORMAT;
}

auto pad_mode = GetValue<std::string>(p->GetAttr("padding"));
auto pad_mode = GetValue<std::string>(prim->GetAttr("padding"));
if (pad_mode == "VALID") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "SAME") {
@@ -50,19 +50,19 @@ int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
attr->padMode = schema::PadMode_NOTSET;
}

auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("ksize"));
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("ksize"));
attr->windowH = kernel_size[2];
attr->windowW = kernel_size[3];

auto stride = GetValue<std::vector<int>>(p->GetAttr("strides"));
auto stride = GetValue<std::vector<int>>(prim->GetAttr("strides"));
attr->strideH = stride[2];
attr->strideW = stride[3];

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Pooling;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Pooling;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfMaxPoolParser("MaxPool", new AnfPoolPopulater());
AnfNodePopulaterRegistrar anfMaxPoolPopulater("MaxPool", new AnfPoolPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_POOL_PARSER_H
#define MINDSPORE_ANF_POOL_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfPoolPopulater : public AnfNodePopulater {
public:
AnfPoolPopulater() = default;
~AnfPoolPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.cc View File

@@ -13,23 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_quant_populater.h"
#include "src/common/anf_importer/anf_populater/anf_quant_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfQuantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfQuantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::OnnxInt8QuantizeT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfQuantParser("Quant", new AnfQuantPopulater());
AnfNodePopulaterRegistrar anfQuantPopulater("Quant", new AnfQuantPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_QUANT_PARSER_H
#define MINDSPORE_ANF_QUANT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfQuantPopulater : public AnfNodePopulater {
public:
AnfQuantPopulater() = default;
~AnfQuantPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.cc View File

@@ -13,10 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_reducemean_populater.h"
#include "src/common/anf_importer/anf_populater/anf_reducemean_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

@@ -25,15 +25,15 @@ namespace {
constexpr int kReduceInputNum = 3;
constexpr int kReduceInputIndex = 2;
}
int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::ReduceT>();
attr->mode = schema::ReduceMode_ReduceMean;

attr->keepDims = GetValue<bool>(p->GetAttr("keep_dims"));
if (cnodePtr->inputs().size() == kReduceInputNum) {
auto inputNode = cnodePtr->input(kReduceInputIndex);
attr->keepDims = GetValue<bool>(prim->GetAttr("keep_dims"));
if (inputs.size() == kReduceInputNum) {
auto inputNode = inputs[kReduceInputIndex];
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<ValueNode>()) {
auto valueNode = inputNode->cast<ValueNodePtr>();
@@ -52,11 +52,11 @@ int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CN
}
}

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Reduce;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Reduce;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfReduceMeanParser("ReduceMean", new AnfReduceMeanPopulater());
AnfNodePopulaterRegistrar anfReduceMeanPopulater("ReduceMean", new AnfReduceMeanPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H
#define MINDSPORE_ANF_ACTIVATION_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfReduceMeanPopulater : public AnfNodePopulater {
public:
AnfReduceMeanPopulater() = default;
~AnfReduceMeanPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.cc View File

@@ -13,19 +13,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h"
#include "src/common/anf_importer/anf_populater/anf_reshape_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfReshapePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::ReshapeT>();
MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree);
auto inputNode = cnodePtr->input(kAnfPopulaterTwo);
MS_ASSERT(inputs.size() == kAnfPopulaterThree);
auto inputNode = inputs[kAnfPopulaterTwo];
if (inputNode->isa<ValueNode>()) {
auto valueNode = inputNode->cast<ValueNodePtr>();
MS_ASSERT(valueNode != nullptr);
@@ -42,12 +43,12 @@ int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, sc
}
}

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Reshape;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Reshape;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}

AnfNodePopulaterRegistrar anfReshapeParser("Reshape", new AnfReshapePopulater());
AnfNodePopulaterRegistrar anfReshapePopulater("Reshape", new AnfReshapePopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.h View File

@@ -16,14 +16,15 @@

#ifndef MINDSPORE_ANF_RESHAPE_PARSER_H
#define MINDSPORE_ANF_RESHAPE_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfReshapePopulater : public AnfNodePopulater {
public:
AnfReshapePopulater() = default;
~AnfReshapePopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.cc View File

@@ -13,22 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h"
#include "src/common/anf_importer/anf_populater/anf_tensoradd_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfTensorAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfTensorAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::AddT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Add;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Add;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfTensorAddParser("TensorAdd", new AnfTensorAddPopulater());
AnfNodePopulaterRegistrar anfTensorAddPopulater("TensorAdd", new AnfTensorAddPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H
#define MINDSPORE_ANF_ACTIVATION_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfTensorAddPopulater : public AnfNodePopulater {
public:
AnfTensorAddPopulater() = default;
~AnfTensorAddPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.cc View File

@@ -13,21 +13,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_transpose_populater.h"
#include "src/common/anf_importer/anf_populater/anf_transpose_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfTransposePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::TransposeT>();

MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree);
auto inputNode = cnodePtr->input(kAnfPopulaterTwo);
MS_ASSERT(inputs.size() == kAnfPopulaterThree);
auto inputNode = inputs[kAnfPopulaterTwo];
if (inputNode->isa<ValueNode>()) {
auto valNode = inputNode->cast<ValueNodePtr>();
MS_ASSERT(valNode != nullptr);
@@ -44,11 +44,11 @@ int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr,
}
}

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Transpose;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Transpose;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfTransposeParser("Transpose", new AnfTransposePopulater());
AnfNodePopulaterRegistrar anfTransposePopulater("Transpose", new AnfTransposePopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H
#define MINDSPORE_ANF_TRANSPOSE_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfTransposePopulater : public AnfNodePopulater {
public:
AnfTransposePopulater() = default;
~AnfTransposePopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc → mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.cc View File

@@ -13,22 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h"
#include "src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"

namespace mindspore::lite {
int mindspore::lite::AnfTupleGetItemPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfTupleGetItemPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::TupleGetItemT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_TupleGetItem;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_TupleGetItem;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfTupleGetItemParser("tuple_getitem", new AnfTupleGetItemPopulater());
AnfNodePopulaterRegistrar anfTupleGetItemPopulater("tuple_getitem", new AnfTupleGetItemPopulater());
} // namespace mindspore::lite

mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h → mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h View File

@@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H
#define MINDSPORE_ANF_BATCHNORM_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfTupleGetItemPopulater : public AnfNodePopulater {
public:
AnfTupleGetItemPopulater() = default;
~AnfTupleGetItemPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite


+ 14
- 2
mindspore/lite/src/common/anf_importer/import_from_protobuf.cc View File

@@ -28,6 +28,7 @@
#include <unordered_map>
#include <vector>

#include "schema/inner/model_generated.h"
#include "frontend/operator/ops.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "include/errorcode.h"
@@ -38,6 +39,7 @@
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "utils/log_adapter.h"
#include "securec/include/securec.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"

using string = std::string;
using int32 = int32_t;
@@ -997,10 +999,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
return nullptr;
}
}

std::vector<AnfNodePtr> inputs;
inputs.clear();
inputs.push_back(NewValueNode(prim));
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
@@ -1009,6 +1009,18 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
}
inputs.push_back(anfnode_build_map_[input_name]);
}
std::string opType = prim->name();
auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
if (node_parser == nullptr) {
MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
return nullptr;
}
auto primitiveT = std::make_unique<schema::PrimitiveT>();
// auto * primitiveTValue = new PrimitiveTValue(primitiveT.release());
std::shared_ptr<PrimitiveTValue> primitiveTValuePtr = std::make_shared<PrimitiveTValue>(primitiveT.release());
node_parser->Populate(prim, primitiveTValuePtr.get(), inputs);
MS_ASSERT(primitiveTValuePtr != nullptr);
inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr));
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
if (node_type == "LayerNorm") {


+ 4
- 4
mindspore/lite/test/CMakeLists.txt View File

@@ -215,9 +215,9 @@ if(BUILD_CONVERTER)
${TEST_CASE_TFLITE_PARSERS_SRC}
${TOP_DIR}/mindspore/core/utils/flags.cc
${LITE_DIR}/tools/converter/optimizer.cc
${LITE_DIR}/src/common/anf_importer/anf_importer.cc
${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc
${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc
# ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc
# ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc
# ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc
${LITE_DIR}/tools/converter/anf_transform.cc
${LITE_DIR}/tools/converter/graphdef_transform.cc
${LITE_DIR}/tools/converter/converter_flags.cc
@@ -345,7 +345,7 @@ if (BUILD_MINDDATA)
endif()
if (BUILD_CONVERTER)
target_link_libraries(lite-test
anf_exporter_mid
anf_importer_mid
tflite_parser_mid
caffe_parser_mid
node_mid


+ 2
- 4
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -70,9 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/anf_importer.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_meta_graphT.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_protobuf.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc

../optimizer/common/node_pass_extends.cc
@@ -99,7 +97,7 @@ add_executable(converter_lite
target_link_libraries(converter_lite PRIVATE
tflite_parser_mid
caffe_parser_mid
anf_exporter_mid
anf_importer_mid
node_mid
graph_pass_mid
fusion_mid


+ 1
- 0
mindspore/lite/tools/converter/quantizer/CMakeLists.txt View File

@@ -10,6 +10,7 @@ add_library(quantizer_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../src/common/anf_exporter/anf_exporter.cc
)

if(ENABLE_ASAN)


Loading…
Cancel
Save