Browse Source

repair onnx parser

tags/v0.7.0-beta
wangzhe 5 years ago
parent
commit
dc29dd472b
83 changed files with 451 additions and 353 deletions
  1. +4
    -0
      mindspore/lite/src/ops/reshape.cc
  2. +1
    -1
      mindspore/lite/test/CMakeLists.txt
  3. +2
    -1
      mindspore/lite/tools/converter/CMakeLists.txt
  4. +5
    -0
      mindspore/lite/tools/converter/converter.cc
  5. +3
    -3
      mindspore/lite/tools/converter/converter_flags.cc
  6. +6
    -1
      mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc
  7. +1
    -1
      mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
  8. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc
  9. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h
  10. +78
    -49
      mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc
  11. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h
  12. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc
  13. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h
  14. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc
  15. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h
  16. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc
  17. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h
  18. +16
    -9
      mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc
  19. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h
  20. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc
  21. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h
  22. +4
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc
  23. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h
  24. +23
    -19
      mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc
  25. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h
  26. +2
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc
  27. +2
    -3
      mindspore/lite/tools/converter/parser/onnx/onnx_converter.h
  28. +2
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc
  29. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h
  30. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc
  31. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h
  32. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc
  33. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h
  34. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc
  35. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h
  36. +5
    -5
      mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc
  37. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h
  38. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc
  39. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h
  40. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc
  41. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h
  42. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc
  43. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h
  44. +5
    -6
      mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc
  45. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h
  46. +79
    -65
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
  47. +10
    -16
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
  48. +1
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
  49. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
  50. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc
  51. +1
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h
  52. +6
    -5
      mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc
  53. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h
  54. +5
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc
  55. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h
  56. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc
  57. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h
  58. +28
    -20
      mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc
  59. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h
  60. +15
    -16
      mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc
  61. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h
  62. +4
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc
  63. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h
  64. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc
  65. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h
  66. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc
  67. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h
  68. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc
  69. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h
  70. +5
    -6
      mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc
  71. +4
    -4
      mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h
  72. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc
  73. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h
  74. +4
    -3
      mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc
  75. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h
  76. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc
  77. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h
  78. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc
  79. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h
  80. +3
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc
  81. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h
  82. +2
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc
  83. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h

+ 4
- 0
mindspore/lite/src/ops/reshape.cc View File

@@ -102,6 +102,10 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
auto data = reinterpret_cast<int32_t *>(shape_tensor->Data());
CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeInt64: {
auto data = reinterpret_cast<int64_t *>(shape_tensor->Data());
CalShape<int64_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(shape_tensor->Data());
CalShape<float>(data, inputs_, &out_shape, shape_size);


+ 1
- 1
mindspore/lite/test/CMakeLists.txt View File

@@ -223,7 +223,6 @@ if(BUILD_CONVERTER)
${LITE_DIR}/tools/converter/graphdef_transform.cc
${LITE_DIR}/tools/converter/converter_flags.cc
${LITE_DIR}/tools/converter/converter.cc
${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc
${LITE_DIR}/test/st/converter_test.cc
${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc
${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc
@@ -351,6 +350,7 @@ if (BUILD_CONVERTER)
anf_importer_mid
tflite_parser_mid
caffe_parser_mid
onnx_parser_mid
node_mid
graph_pass_mid
fusion_mid


+ 2
- 1
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -71,7 +71,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc

../optimizer/common/node_pass_extends.cc
../optimizer/common/pass_manager_extends.cc
@@ -86,6 +85,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}

add_subdirectory(parser/caffe)
add_subdirectory(parser/tflite)
add_subdirectory(parser/onnx)
add_subdirectory(legacy_optimizer)
add_subdirectory(quantizer)

@@ -98,6 +98,7 @@ add_executable(converter_lite
target_link_libraries(converter_lite PRIVATE
tflite_parser_mid
caffe_parser_mid
onnx_parser_mid
anf_importer_mid
node_mid
graph_pass_mid


+ 5
- 0
mindspore/lite/tools/converter/converter.cc View File

@@ -27,6 +27,7 @@
#include "tools/common/storage.h"
#include "parser/caffe/caffe_converter.h"
#include "parser/tflite/tflite_converter.h"
#include "parser/onnx/onnx_converter.h"
#include "src/common/anf_exporter/anf_exporter.h"
#include "src/common/anf_importer/import_from_protobuf.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
@@ -185,6 +186,10 @@ int RunConverter(int argc, const char **argv) {
TfliteConverter tfLiteConverter;
fb_graph = tfLiteConverter.Convert(flags);
} break;
case FmkType::FmkType_ONNX: {
OnnxConverter onnxConverter;
fb_graph = onnxConverter.Convert(flags);
} break;
default: {
MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk;
return 1;


+ 3
- 3
mindspore/lite/tools/converter/converter_flags.cc View File

@@ -14,13 +14,11 @@
* limitations under the License.
*/


#include "tools/converter/converter_flags.h"
#include <regex>
#include <string>
#include "ir/dtype/type_id.h"


namespace mindspore {
namespace lite {
namespace converter {
@@ -89,8 +87,10 @@ int Flags::Init(int argc, const char **argv) {
this->fmk = FmkType_MS;
} else if (this->fmkIn == "TFLITE") {
this->fmk = FmkType_TFLITE;
} else if (this->fmkIn == "ONNX") {
this->fmk = FmkType_ONNX;
} else {
std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS";
std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX";
return 1;
}



+ 6
- 1
mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc View File

@@ -138,6 +138,12 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
}
beforeNodeType = kNCHW2NHWC;
afterNodeType = kNHWC2NCHW;
} else if (fmkType == converter::FmkType_ONNX) {
if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
continue;
}
beforeNodeType = kNCHW2NHWC;
afterNodeType = kNHWC2NCHW;
} else {
MS_LOG(ERROR) << "Unsupported fmk: " << fmkType;
return RET_ERROR;
@@ -197,4 +203,3 @@ void FormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkTy

} // namespace lite
} // namespace mindspore


+ 1
- 1
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc View File

@@ -189,7 +189,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
if (opType == schema::PrimitiveType_Conv2D) {
weightTensor->format = schema::Format_KCHW;
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
weightTensor->format = schema::Format_CKHW;
weightTensor->format = schema::Format_KCHW;
} else if (opType == schema::PrimitiveType_DeConv2D) {
weightTensor->format = schema::Format_CKHW;
} else {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h"
#include "tools/converter/parser/onnx/onnx_argmax_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());
MS_LOG(DEBUG) << "onnx ArgMaxParser";
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_ARGMAX_PARSER_H
#define MS_ONNX_ARGMAX_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 78
- 49
mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc View File

@@ -15,111 +15,118 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h"
#include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AddParser";
if (op != nullptr) {
std::unique_ptr<schema::AddT> attr(new schema::AddT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SubParser";
if (op != nullptr) {
std::unique_ptr<schema::SubT> attr(new schema::SubT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx MulParser";
if (op != nullptr) {
std::unique_ptr<schema::MulT> attr(new schema::MulT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DivParser";
if (op != nullptr) {
std::unique_ptr<schema::DivT> attr(new schema::DivT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_RealDiv;
op->primitive->value.value = nullptr;
}
return RET_OK;
}

STATUS OnnxMeanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Mean;
op->primitive->value.value = nullptr;
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PowParser";
if (op != nullptr) {
// TODO(wangzhe) attr power need populate
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx EqualParser";
if (op != nullptr) {
std::unique_ptr<schema::EqualT> attr(new schema::EqualT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx LessParser";
if (op != nullptr) {
std::unique_ptr<schema::LessT> attr(new schema::LessT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx GreaterParser";
if (op != nullptr) {
std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx MinParser";
if (op != nullptr) {
std::unique_ptr<schema::MinT> attr(new schema::MinT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Min;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx EltwiseParser";
std::unique_ptr<schema::EltwiseT> attr(new schema::EltwiseT());
if (onnx_node.op_type() == "Prod") {
attr->mode = schema::EltwiseMode_PROD;
} else if (onnx_node.op_type() == "Sum") {
// there is no Prod in onnx
if (onnx_node.op_type() == "Sum") {
attr->mode = schema::EltwiseMode_SUM;
} else if (onnx_node.op_type() == "Maximum") {
} else if (onnx_node.op_type() == "Max") {
attr->mode = schema::EltwiseMode_MAXIMUM;
}

@@ -131,109 +138,133 @@ STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph,
return RET_OK;
}

STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx FloorParser";
if (op != nullptr) {
std::unique_ptr<schema::FloorT> attr(new schema::FloorT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AbsParser";
if (op != nullptr) {
std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx NegParser";
if (op != nullptr) {
std::unique_ptr<schema::NegT> attr(new schema::NegT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Neg;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ExpParser";
if (op != nullptr) {
std::unique_ptr<schema::ExpT> attr(new schema::ExpT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx CosParser";
if (op != nullptr) {
std::unique_ptr<schema::CosT> attr(new schema::CosT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SinParser";
if (op != nullptr) {
std::unique_ptr<schema::SinT> attr(new schema::SinT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Sin;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SqrtParser";
if (op != nullptr) {
std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Sqrt;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx CeilParser";
if (op != nullptr) {
std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx LogParser";
if (op != nullptr) {
std::unique_ptr<schema::LogT> attr(new schema::LogT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TanParser";
if (op != nullptr) {
std::unique_ptr<schema::TanT> attr(new schema::TanT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Tan;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AtanParser";
if (op != nullptr) {
std::unique_ptr<schema::AtanT> attr(new schema::AtanT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Atan;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AsinParser";
if (op != nullptr) {
std::unique_ptr<schema::AsinT> attr(new schema::AsinT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Asin;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TanhParser";
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.value = nullptr;
MS_LOG(ERROR) << "mslite don't support tanh now";
return RET_ERROR;
}
return RET_OK;
}
@@ -243,13 +274,12 @@ OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser());
OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser());
OnnxNodeRegistrar g_onnxMeanParser("Mean", new OnnxMeanParser());
// OnnxNodeRegistrar g_onnxMeanParser("Mean", new OnnxMeanParser()); // onnx's Mean is different from mslite's
OnnxNodeRegistrar g_onnxPowParser("Power", new OnnxPowParser());
OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser());
OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser());
OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser());
OnnxNodeRegistrar g_onnxMinParser("Min", new OnnxMinParser());
OnnxNodeRegistrar g_onnxProdParser("Prod", new OnnxEltwiseParser());
OnnxNodeRegistrar g_onnxSumParser("Sum", new OnnxEltwiseParser());
OnnxNodeRegistrar g_onnxMaxParser("Max", new OnnxEltwiseParser());
OnnxNodeRegistrar g_onnxFloorParser("Floor", new OnnxFloorParser());
@@ -267,4 +297,3 @@ OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser());
OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser());
} // namespace lite
} // namespace mindspore


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
#define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc View File

@@ -14,14 +14,15 @@
* limitations under the License.
*/

#include "mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h"
#include "tools/converter/parser/onnx/onnx_batchnorm_parser.h"
#include <memory>

namespace mindspore {
namespace lite {
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::FusedBatchNormT> attr(new schema::FusedBatchNormT());
MS_LOG(DEBUG) << "onnx BatchNormParser";
std::unique_ptr<schema::FusedBatchNormT> attr(new schema::FusedBatchNormT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "epsilon") {
attr->epsilon = onnx_node_attr.f();


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_ADD_PARSER_H
#define MS_ONNX_ADD_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc View File

@@ -15,7 +15,7 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h"
#include "tools/converter/parser/onnx/onnx_biasadd_parser.h"

// using namespace mindspore::predict;
// using namespace onnx;
@@ -25,7 +25,8 @@ namespace lite {
STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::BiasAddT> attr(new schema::BiasAddT());
MS_LOG(DEBUG) << "onnx BiasAddParser";
std::unique_ptr<schema::BiasAddT> attr(new schema::BiasAddT());
// use channel dim as axis
attr->axis = {1};
if (op != nullptr) {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_BIASADD_PARSER_H
#define MS_ONNX_BIASADD_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc View File

@@ -15,12 +15,13 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h"
#include "tools/converter/parser/onnx/onnx_cast_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
unique_ptr<schema::CastT> attr(new schema::CastT());
MS_LOG(DEBUG) << "onnx CastParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "to") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_CAST_PARSER_H
#define MS_ONNX_CAST_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 16
- 9
mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc View File

@@ -15,24 +15,32 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h"
#include "tools/converter/parser/onnx/onnx_clip_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
unique_ptr<schema::ClipT> attr(new schema::ClipT());
MS_LOG(DEBUG) << "onnx ClipParser";
float min = -1, max = -1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "max") {
attr->max = onnx_node_attr.f();
max = onnx_node_attr.f();
} else if (attribute_name == "min") {
attr->min = onnx_node_attr.f();
min = onnx_node_attr.f();
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Clip;
op->primitive->value.value = attr.release();
if (min == 0 && max == 6) {
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
attr->type = schema::ActivationType_RELU6;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
}
} else {
MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported";
return RET_PARAM_INVALID;
}
return RET_OK;
}
@@ -40,4 +48,3 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser());
} // namespace lite
} // namespace mindspore


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_CLIP_PARSER_H
#define MS_ONNX_CLIP_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h"
#include "tools/converter/parser/onnx/onnx_concat_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::ConcatT> attr(new schema::ConcatT());
MS_LOG(DEBUG) << "onnx ConcatParser";
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_CONCAT_PARSER_H
#define MS_ONNX_CONCAT_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 4
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc View File

@@ -15,17 +15,19 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h"
#include "tools/converter/parser/onnx/onnx_constant_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConstantParser";
if (op != nullptr) {
std::unique_ptr<schema::ConstantT> attr(new schema::ConstantT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Constant;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_CONSTANT_PARSER_H
#define MS_ONNX_CONSTANT_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 23
- 19
mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc View File

@@ -17,17 +17,18 @@
#include <vector>
#include <memory>
#include <algorithm>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h"
#include "tools/converter/parser/onnx/onnx_conv_parser.h"

namespace mindspore {
namespace lite {
bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) {
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
if (attr == nullptr || attr->group != attr->channelIn) {
return false;
}
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam(new (std::nothrow) schema::DepthwiseConv2DT());
if (depthwiseConv2DParam == nullptr) {
// MS_LOGW("new DepthwiseConv2DT failed");
MS_LOG(ERROR) << "new DepthwiseConv2DT failed";
return false;
}
depthwiseConv2DParam->format = attr->format;
@@ -48,12 +49,12 @@ bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *
depthwiseConv2DParam->activationType = attr->activationType;
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
delete (op->primitive->value.value);
op->primitive->value.value = depthwiseConv2DParam.release();
return true;
}

STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConvParser";
auto attr = new schema::Conv2DT();
// set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@@ -61,30 +62,32 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->group = static_cast<int32_t>(onnx_node_attr.i());
} else if (onnx_node_attr.name() == "dilations") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
// TODO(wangzhe) verify the change
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(1));
// TODO(wangzhe) verify the change
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "auto_pad") {
attr->padMode = GetOnnxPadMode(onnx_node_attr);
} else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) {
// MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
return RET_ERROR;
}
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
@@ -93,16 +96,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
} else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
// TODO(wangzhe) verify the change
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "order") {
if (onnx_node_attr.s() == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
// MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str());
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s();
return RET_ERROR;
}
}
@@ -114,7 +118,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (nodeIter == onnx_graph.initializer().end()) {
// MS_LOGE("not find node: %s", onnx_conv_weight.c_str())
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight;
return RET_ERROR;
}
std::vector<int> weight_shape;
@@ -129,7 +133,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
[onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; });
if (nodeIter == onnx_graph.node().end()) {
// MS_LOGE("can not find node: %s", onnx_conv_weight.c_str())
MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight;
return RET_ERROR;
}
std::vector<int> dims;
@@ -139,6 +143,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end());
}
attr->channelOut = dims[0];
// TODO(wangzhe) verify this code
attr->channelIn = dims[3] * attr->group;
}
attr->format = schema::Format_NCHW;
@@ -156,7 +161,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
if (attr->group != 1) {
if (!ParseGroupConvolution(op, attr)) {
delete attr;
// MS_LOGE("Convert Convolution to Depthwise failed");
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return RET_ERROR;
}
}
@@ -169,4 +174,3 @@ OnnxNodeRegistrar g_onnxConvReluParser("ConvRelu", new OnnxConvParser());
OnnxNodeRegistrar g_onnxInt8ConvReluParser("Int8ConvRelu", new OnnxConvParser());
} // namespace lite
} // namespace mindspore


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_CONV_PARSER_H
#define MS_ONNX_CONV_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 2
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc View File

@@ -14,7 +14,8 @@
* limitations under the License.
*/

#include "mindspore/lite/tools/converter/parser/onnx/onnx_converter.h"
#include "tools/converter/parser/onnx/onnx_converter.h"
#include "tools/converter/parser/onnx/onnx_model_parser.h"

namespace mindspore {
namespace lite {


+ 2
- 3
mindspore/lite/tools/converter/parser/onnx/onnx_converter.h View File

@@ -18,9 +18,8 @@
#define MS_ONNX_CONVERTER_H
#include <string>
#include <memory>
#include "mindspore/lite/tools/converter/converter.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h"
#include "mindspore/lite/tools/converter/graphdef_transform.h"
#include "tools/converter/converter.h"
#include "tools/converter/graphdef_transform.h"

namespace mindspore {
namespace lite {


+ 2
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc View File

@@ -17,11 +17,12 @@
#include <vector>
#include <memory>
#include <algorithm>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h"
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"

namespace mindspore {
namespace lite {
bool OnnxDeConvParser::ParseGroupDeConvolution(schema::CNodeT *op, schema::DeConv2DT *attr) {
MS_LOG(DEBUG) << "onnx DeConvParser";
if (attr == nullptr || attr->group != attr->channelOut) {
return false;
}


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_DECONV_PARSER_H
#define MS_ONNX_DECONV_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h"
#include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT());
MS_LOG(DEBUG) << "onnx DepthToSpaceParser";
std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "blocksize") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_DEPTH_TO_SPACE_PARSER_H
#define MS_ONNX_DEPTH_TO_SPACE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h"
#include "tools/converter/parser/onnx/onnx_dropout_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::DropoutT> attr(new schema::DropoutT());
MS_LOG(DEBUG) << "onnx DropoutParser";
std::unique_ptr<schema::DropoutT> attr(new schema::DropoutT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "ratio") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_ARGMAX_PARSER_H
#define MS_ONNX_ARGMAX_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc View File

@@ -15,12 +15,13 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h"
#include "tools/converter/parser/onnx/onnx_elu_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
unique_ptr<schema::EluT> attr(new schema::EluT());
MS_LOG(DEBUG) << "onnx EluParser";
std::unique_ptr<schema::EluT> attr(new schema::EluT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "alpha") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_ELU_PARSER_H
#define MS_ONNX_ELU_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 5
- 5
mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc View File

@@ -15,17 +15,18 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h"
#include "tools/converter/parser/onnx/onnx_expand_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ExpandParser";
if (op != nullptr) {
std::unique_ptr<schema::BroadcastT> attr(new schema::BroadcastT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Broadcast;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
@@ -33,4 +34,3 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph,
OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser());
} // namespace lite
} // namespace mindspore


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_EXPAND_PARSER_H
#define MS_ONNX_EXPAND_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h"
#include "tools/converter/parser/onnx/onnx_flatten_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT());
MS_LOG(DEBUG) << "onnx FlattenParser";
std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT());
int axis = 1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_FLATTEN_PARSER_H
#define MS_ONNX_FLATTEN_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h"
#include "tools/converter/parser/onnx/onnx_gather_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::GatherT> attr(new schema::GatherT());
MS_LOG(DEBUG) << "onnx GatherParser";
std::unique_ptr<schema::GatherT> attr(new schema::GatherT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_GATHER_PARSER_H
#define MS_ONNX_GATHER_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc View File

@@ -15,12 +15,13 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h"
#include "tools/converter/parser/onnx/onnx_lrn_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
unique_ptr<schema::LrnT> attr(new schema::LrnT());
MS_LOG(DEBUG) << "onnx LrnParser";
std::unique_ptr<schema::LrnT> attr(new schema::LrnT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "size") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_LRN_PARSER_H
#define MS_ONNX_LRN_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 5
- 6
mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc View File

@@ -15,14 +15,14 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h"
#include "tools/converter/parser/onnx/onnx_matmul_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::MatMulT> attr(new schema::MatMulT());
MS_LOG(DEBUG) << "onnx MatMulParser";
std::unique_ptr<schema::MatMulT> attr(new schema::MatMulT());
float alpha = 1.0f;
float beta = 1.0f;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@@ -38,7 +38,7 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph,
}
}
if (alpha != 1 || beta != 1) {
// MS_LOGE("not support alpha * A * B + beta * C");
MS_LOG(ERROR) << "not support alpha * A * B + beta * C";
return RET_PARAM_INVALID;
}

@@ -53,4 +53,3 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph,
OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser());
} // namespace lite
} // namespace mindspore


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_MATMUL_PARSER_H
#define MS_ONNX_MATMUL_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 79
- 65
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc View File

@@ -18,7 +18,7 @@
#include <unordered_map>
#include <algorithm>
#include <utility>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h"
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "tools/common/graph_util.h"
#include "src/common/utils.h"

@@ -35,11 +35,12 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32},
{onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64},
{onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16},
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat}};
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}};

TypeId OnnxModelParser::GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
auto iter = TYPE_MAP.find(onnx_type);
if (iter == TYPE_MAP.end()) {
MS_LOG(ERROR) << "unsupported onnx data type: " << onnx_type;
return kTypeUnknown;
}
return iter->second;
@@ -56,7 +57,7 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo
STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) {
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
if (realpath(modelFile.c_str(), onnx_file.get()) == nullptr) {
// MS_LOGE("get realpath %s fail", modelFile.c_str());
MS_LOG(ERROR) << "get realpath " << modelFile << " fail";
return RET_ERROR;
}
int fd = open(onnx_file.get(), O_RDONLY);
@@ -65,7 +66,7 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go
code_input.SetTotalBytesLimit(INT_MAX, 536870912);
bool ret = onnx_model->ParseFromCodedStream(&code_input);
if (!ret) {
// MS_LOGE("load onnx file failed");
MS_LOG(ERROR) << "load onnx file failed";
return RET_ERROR;
}
(void)close(fd);
@@ -73,46 +74,47 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go
}

STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) {
// MS_LOGD("set onnx constant tensors");
MS_LOG(DEBUG) << "set onnx constant tensors";
for (const auto &onnx_const_value : onnx_graph.initializer()) {
std::vector<int32_t> dims;
std::copy(onnx_const_value.dims().begin(), onnx_const_value.dims().end(), std::back_inserter(dims));
auto data_type = GetDateTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type()));
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type()));
if (data_type == kTypeUnknown) {
// MS_LOGE("not support onnx type %d", static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type()));
MS_LOG(ERROR) << "not support onnx data type "
<< static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type());
return RET_ERROR;
}
std::unique_ptr<schema::TensorT> tensor(new (std::nothrow) schema::TensorT);
if (tensor == nullptr) {
// MS_LOGE("new tensor failed");
MS_LOG(ERROR) << "new tensor failed";
return RET_ERROR;
}
tensor->dataType = data_type;
tensor->format = schema::Format_NCHW;
for (const auto &it : dims) {
tensor->dims.emplace_back(it);
}
tensor->format = schema::Format_NCHW; // onnx use NCHW
std::copy(onnx_const_value.dims().begin(), onnx_const_value.dims().end(), std::back_inserter(tensor->dims));
tensor->nodeType = schema::NodeType_ValueNode;
if (CopyOnnxTensorData(onnx_const_value, tensor.get())) {
MS_LOG(ERROR) << "copy onnx data failed";
return RET_ERROR;
}
// const auto index = tensor_cache->AddTensor(onnx_const_value.name(), tensor.release(), GRAPH_INPUT);
// MS_LOGD("add const tensor: %s, index %d", onnx_const_value.name().c_str(), index)
// TODO(wangzhe) why use GRAPH_INPUT other than CONST(GRAPH_INPUT will add index to graphInputs)
const auto index = tensor_cache->AddTensor(onnx_const_value.name(), tensor.release(), GRAPH_INPUT);
MS_LOG(DEBUG) << "add const tensor: " << onnx_const_value.name() << ", index " << index;
}
return RET_OK;
}

// TODO(wangzhe) seems AddTensorCache should be renamed to prepare tensor to add to tensor_cache
STATUS OnnxModelParser::AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor) {
auto data_type = GetDateTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
if (data_type == kTypeUnknown) {
// MS_LOGE("not support onnx type %d",
// static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
MS_LOG(ERROR) << "not support onnx type "
<< static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type());
return RET_ERROR;
}
tensor->dataType = data_type;
tensor->dims = GetDimsFromOnnxValue(proto);
tensor->format = schema::Format_NCHW;
tensor->nodeType = schema::NodeType_ValueNode;
// TODO(wangzhe) tensor->data and quantParams not set, should we need tensor_cache->AddTensor?
return RET_OK;
}

@@ -122,12 +124,14 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
auto ret = tensor_cache->FindTensor(input_value.name());
if (ret < 0) {
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);
// TODO(wangzhe) why there is an addtensorCache?
if (AddTensorCache(input_value, tensor.get())) {
return RET_ERROR;
}
// TODO(wangzhe) why inputTensor is value and should be added into tensor_cache?
auto tensor_index = tensor_cache->AddTensor(input_value.name(), tensor.release(), GRAPH_INPUT);
graph->inputIndex.emplace_back(static_cast<uint32_t>(tensor_index));
// MS_LOGD("input_value name: %s, graph input index: %d", input_value.name().c_str(), tensor_index);
MS_LOG(DEBUG) << "input_value name: " << input_value.name() << ", graph input index: " << tensor_index;
}
}
return RET_OK;
@@ -140,9 +144,10 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
if (AddTensorCache(output_value, tensor.get())) {
return RET_ERROR;
}
// TODO(wangzhe) why we need AddTensor at OutputTensor
auto tensor_index = tensor_cache->AddTensor(output_value.name(), tensor.release(), OP_OUTPUT);
graph->outputIndex.emplace_back(tensor_index);
// MS_LOGD("output_value name: %s, graph output index: %d", output_value.name().c_str(), tensor_index);
MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << tensor_index;
}
return RET_OK;
}
@@ -151,7 +156,6 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
schema::MetaGraphT *graph, TensorCache *tensor_cache) {
std::unique_ptr<schema::CNodeT> dst_op_1(new schema::CNodeT);
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
// dst_op_1->fmkType = FmkType_ONNX;
ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0);
std::vector<string> matmul_inputs{onnx_node.input(0), onnx_node.input(1)};
@@ -162,7 +166,6 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons

std::unique_ptr<schema::CNodeT> dst_op_2(new schema::CNodeT);
dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0);
// dst_op_2->fmkType = FmkType_ONNX;
ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get());
std::vector<string> biasadd_inputs{matmul_output_id, onnx_node.input(2)};
std::vector<string> biasadd_outputs{onnx_node.output(0)};
@@ -181,7 +184,7 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
[](const onnx::AttributeProto &attr) { return attr.name() == "shape"; });
if (iter != onnx_node.attribute().end()) {
(void)shape.insert(shape.begin(), iter->ints().begin(), iter->ints().end());
std::for_each(shape.begin(), shape.end(), [](int sh) { /*MS_LOGD("shape: %d", sh);*/ });
std::for_each(shape.begin(), shape.end(), [](int sh) { MS_LOG(DEBUG) << "shape: " << sh; });
}
tensor->dims = shape;
tensor->format = schema::Format_NUM_OF_FORMAT;
@@ -210,51 +213,50 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
// todo: add * sizof(string)
data_size = data_count;
tensor->data.resize(data_size);
// MS_LOGD("tensor data size %lu, s: %lu", data_size, sizeof(iter->s().data()));
MS_LOG(DEBUG) << "tensor data size " << data_size << ", s: " << sizeof(iter->s().data());
if (memcpy_s(tensor->data.data(), data_size, iter->s().data(), data_size) != 0) {
// MS_LOGE("memcpy_s failed")
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
} else {
// MS_LOGE("unsupported data type %d", tensor->dataType);
MS_LOG(ERROR) << "unsupported data type " << tensor->dataType;
return RET_ERROR;
}
}
auto index = tensor_cache->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT);
// MS_LOGD("add given tensor: %d", index);
MS_LOG(DEBUG) << "add given tensor: " << index;
}
return RET_OK;
}

STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
TensorCache *tensor_cache) {
// change op_type() to name(), that is unique
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
// dst_op->fmkType = FmkType_ONNX;
// MS_LOGD("onnx op name %s, dst op name: %s, input size %d", onnx_node.op_type().c_str(), dst_op->name.c_str(),
// onnx_node.input_size());
MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size "
<< onnx_node.input_size();
// get the real op type
SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
auto status = ParseOnnxNodeAttr(onnx_graph, onnx_node, onnx_node.op_type(), dst_op);
if (status != RET_OK) {
// MS_LOGE("parser onnx node attr failed");
MS_LOG(ERROR) << "parser onnx node attr failed";
return status;
}
// set op input index
std::vector<string> node_inputs;
(void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end());
if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) {
// MS_LOGE("SetOpInputIndex failed");
MS_LOG(ERROR) << "SetOpInputIndex failed";
return RET_ERROR;
}
// set op output index
std::vector<string> node_outputs;
(void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end());

if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) {
// MS_LOGE("SetOpOutputIndex failed");
MS_LOG(ERROR) << "SetOpOutputIndex failed";
return RET_ERROR;
}
return RET_OK;
@@ -286,7 +288,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
for (const auto &node : quant_node) {
std::unique_ptr<schema::QuantParamT> quant_param(new (std::nothrow) schema::QuantParamT());
if (quant_param == nullptr) {
// MS_LOGE("new QuantParamT failed, node: %s", dst_op->name.c_str());
MS_LOG(ERROR) << "new QuantParamT failed, node: " << dst_op->name;
return;
}
int argNum = 0;
@@ -322,7 +324,7 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co
const string &onnx_op_type, schema::CNodeT *dst_op) {
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type);
if (node_parser == nullptr) {
// MS_LOGE("not find %s, node parser is nullptr", onnx_op_type.c_str());
MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr";
return RET_NULL_PTR;
}
return node_parser->Parse(onnx_graph, onnx_node, dst_op);
@@ -332,26 +334,32 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
schema::Format format = schema::Format_MAX;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "order") {
if (onnx_node_attr.name() == "order") { // do we need this code? onnx doc don't have order attr
MS_LOG(EXCEPTION) << "find order attr";
if (onnx_node_attr.s() == "NHWC") {
format = schema::Format_NHWC;
} else {
// MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str());
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s();
return RET_ERROR;
}
}
}
for (const auto &onnx_node_input : node_inputs) {
auto index = tensor_cache->FindTensor(onnx_node_input);
if (index < 0) {
// MS_LOG(ERROR) << onnx_node.name() << " input " << onnx_node_input << " index in tensor_cache " << index;
if (index < 0) { // TODO(wangzhe) can this be ignored? because it's no use
/*
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);
index = tensor_cache->AddTensor(onnx_node_input, tensor.release(), OP_OUTPUT);
*/
MS_LOG(EXCEPTION) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found";
// MS_LOG(INFO) << "new index: " << index;
}
if (format != schema::Format_MAX) {
if (format != schema::Format_MAX) { // TODO(wangzhe) also this
auto inTensor = tensor_cache->GetCachedTensor().at(index);
inTensor->format = format;
}
// MS_LOGD("node: %s, input index: %d", onnx_node_input.c_str(), index);
MS_LOG(DEBUG) << "node: " << onnx_node_input << ", input index: " << index;
dst_op->inputIndex.emplace_back(index);
}
return RET_OK;
@@ -362,23 +370,30 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
for (const auto &onnx_node_output : node_outputs) {
auto index = tensor_cache->FindTensor(onnx_node_output);
if (index < 0) {
MS_LOG(INFO) << "output of node " << dst_op->name << " not in tensor_cache, creating";
MS_LOG(INFO) << "total " << node_outputs.size() << " outputs";
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);

// GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
// tensor->dataType = ;
// tensor->dims = tflite_tensor->shape;
tensor->nodeType = schema::NodeType_Parameter;

index = tensor_cache->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT);
}
// MS_LOGD("node: %s, input index: %d", onnx_node_output.c_str(), index);
MS_LOG(DEBUG) << "node: " << onnx_node_output << ", input index: " << index;
dst_op->outputIndex.emplace_back(index);
}
return RET_OK;
}

STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value,
schema::TensorT *tensor) {
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) {
size_t data_count = 1;
std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
size_t data_size = 0;
const void *tensor_data = nullptr;
switch (tensor->dataType) {
case kNumberTypeFloat:
case kNumberTypeFloat32:
data_size = data_count * sizeof(float);
if (onnx_const_value.float_data_size() == 0) {
tensor_data = onnx_const_value.raw_data().data();
@@ -408,12 +423,12 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
tensor_data = onnx_const_value.raw_data().data();
break;
default:
// MS_LOGE("unsupported data type %d", tensor->dataType);
MS_LOG(ERROR) << "unsupported data type " << tensor->dataType;
return RET_ERROR;
}
tensor->data.resize(data_size);
if (memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) {
// MS_LOGE("memcpy_s failed")
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
return RET_OK;
@@ -441,36 +456,37 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
}
}

MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".onnx") != RET_OK) {
// MS_LOGE("Input illegal: modelFile must be *.onnx");
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
return nullptr;
}
std::unique_ptr<schema::MetaGraphT> dst_graph(new schema::MetaGraphT());
onnx::ModelProto onnx_model;
if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) {
// MS_LOGE("read onnx model fail");
MS_LOG(ERROR) << "read onnx model fail";
return nullptr;
}
const onnx::GraphProto &onnx_graph = onnx_model.graph();
// MS_LOGI("model producer name: %s, graph name: %s", onnx_model.producer_name().c_str(), onnx_graph.name().c_str());
MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
TensorCache tensor_cache;
dst_graph->name = onnx_graph.name();
// dst_graph->name = onnx_graph.name(); // this is not used
// find out input names and const names
FindGraphInputAndConst(onnx_graph);
// set const tensor
if (SetGraphConstTensor(onnx_graph, &tensor_cache)) {
// MS_LOGE("SetGraphConstTensor failed");
MS_LOG(ERROR) << "SetGraphConstTensor failed";
return nullptr;
}
// init onnx model graph input tensor
if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) {
// MS_LOGE("SetGraphInputTensor failed");
MS_LOG(ERROR) << "SetGraphInputTensor failed";
return nullptr;
}
// init onnx model graph output tensor
if (SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) {
// MS_LOGE("SetGraphOutputTensor failed");
MS_LOG(ERROR) << "SetGraphOutputTensor failed";
return nullptr;
}
// init op node input/output tensor, and dst_op attr
@@ -481,7 +497,7 @@ MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::stri
} else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {
auto status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache);
if (status != RET_OK) {
// MS_LOGE("ParseOnnxGivenFillNode failed: %d", status);
MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status;
return nullptr;
}
continue;
@@ -489,18 +505,16 @@ MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::stri

std::unique_ptr<schema::CNodeT> dst_op(new schema::CNodeT);
std::unique_ptr<schema::TensorT> dst_tensor(new schema::TensorT);
if (ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache)) {
// MS_LOGE("parse node %s failed", onnx_node.op_type().c_str())
auto status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache);
if (status != RET_OK) {
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed";
return nullptr;
}
dst_graph->nodes.emplace_back(std::move(dst_op));
}
SetAllTensors(tensor_cache, dst_graph.get());
dst_graph->mempoolSize = 0;
dst_graph->name = GetModelName(modelFile);
return dst_graph.release();
// return Fb2Anf(dst_graph.release());
}
} // namespace lite
} // namespace mindspore


+ 10
- 16
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h View File

@@ -27,9 +27,10 @@
#include <memory>
#include <set>
#include "securec/include/securec.h"
#include "mindspore/lite/tools/converter/model_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/onnx/onnx.pb.h"

namespace mindspore {
namespace lite {
@@ -41,30 +42,24 @@ class OnnxModelParser : public ModelParser {
const QuantType &quantType = QuantType_QUANT_NONE) override;

private:
TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
STATUS ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *model_proto);
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor);
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor, TensorCache *tensor_cache);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph,
TensorCache *tensor_cache);
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const string &onnx_op_type, schema::CNodeT *dst_op);
void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op,
schema::TensorT *dst_tensor, TensorCache *tensor_cache);
STATUS SetOpInputIndex(const std::vector<string> &node_inputs,
schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache);
STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache);
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);
STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef);
@@ -78,4 +73,3 @@ class OnnxModelParser : public ModelParser {
} // namespace mindspore

#endif // MS_ONNX_MODEL_PARSER_H


+ 1
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"

namespace mindspore {
namespace lite {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h View File

@@ -19,9 +19,9 @@

#include <string>
#include "google/protobuf/message.h"
#include "mindspore/lite/tools/converter/proto/onnx.pb.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "tools/common/node_util.h"
#include "mindspore/lite/schema/inner/model_generated.h"
#include "schema/inner/model_generated.h"

// using namespace std;



+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include <string>

namespace mindspore {
@@ -33,13 +33,14 @@ OnnxNodeParser *OnnxNodeParserRegistry::GetNodeParser(const std::string &name) {
if (it != parsers.end()) {
return it->second;
}
/* should not support vague name, otherwise may get wrong parser. ex. PRelu and Relu
for (auto const &i : parsers) {
if (name.find(i.first) != std::string::npos) {
return i.second;
}
}
*/
return nullptr;
}
} // namespace lite
} // namespace mindspore


+ 1
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h View File

@@ -19,8 +19,7 @@

#include <string>
#include <unordered_map>
#include "mindspore/lite/tools/converter/proto/onnx.pb.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"

namespace mindspore {
namespace lite {


+ 6
- 5
mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc View File

@@ -15,12 +15,13 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h"
#include "tools/converter/parser/onnx/onnx_pad_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
unique_ptr<schema::PadT> attr(new schema::PadT());
MS_LOG(DEBUG) << "onnx PadParser";
std::unique_ptr<schema::PadT> attr(new schema::PadT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "pads") {
@@ -33,11 +34,11 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
} else if (attribute_name == "mode") {
const auto &mode = onnx_node_attr.s();
if (mode == "constant") {
attr->paddingmode = schema::PaddingMode_CONSTANT;
attr->paddingMode = schema::PaddingMode_CONSTANT;
} else if (mode == "reflect") {
attr->paddingmode = schema::PaddingMode_REFLECT;
attr->paddingMode = schema::PaddingMode_REFLECT;
} else if (mode == "edge") {
attr->paddingmode = schema::PaddingMode_SYMMETRIC;
attr->paddingMode = schema::PaddingMode_SYMMETRIC;
}
}
}


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_LRN_PARSER_H
#define MS_ONNX_LRN_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 5
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc View File

@@ -15,12 +15,13 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h"
#include "tools/converter/parser/onnx/onnx_pool_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
unique_ptr<schema::PoolingT> attr(new schema::PoolingT());
MS_LOG(DEBUG) << "onnx PoolParser";
std::unique_ptr<schema::PoolingT> attr(new schema::PoolingT());

const auto &pool_type = onnx_node.op_type();
if (pool_type == "MaxPool") {
@@ -41,6 +42,8 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
}

attr->roundMode = schema::RoundMode_FLOOR;
attr->strideW = 1;
attr->strideH = 1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "kernel_shape") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_POOL_PARSER_H
#define MS_ONNX_POOL_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h"
#include "tools/converter/parser/onnx/onnx_reduce_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::ReduceT> attr(new schema::ReduceT());
MS_LOG(DEBUG) << "onnx ReduceParser";
std::unique_ptr<schema::ReduceT> attr(new schema::ReduceT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_REDUCE_PARSER_H
#define MS_ONNX_REDUCE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 28
- 20
mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc View File

@@ -16,12 +16,13 @@

#include <vector>
#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h"
#include "tools/converter/parser/onnx/onnx_relu_parser.h"
#include "securec/include/securec.h"
namespace mindspore {
namespace lite {
STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
MS_LOG(DEBUG) << "onnx ReluParser";
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
const auto &relu_type = onnx_node.op_type();
if (relu_type == "Relu") {
attr->type = schema::ActivationType_RELU;
@@ -30,44 +31,52 @@ STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
}

if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PReluParser";
if (onnx_node.input_size() != 2) {
// MS_LOGE("input num is not 2")
MS_LOG(ERROR) << "input num is not 2";
return RET_PARAM_INVALID;
}
unique_ptr<schema::PreluT> attr(new schema::PreluT());
std::unique_ptr<schema::CaffePReLUT> attr(new schema::CaffePReLUT());
std::vector<onnx::TensorProto> params;
for (int i = 0; i < onnx_node.input_size(); ++i) {
const auto &input_name = onnx_node.input(i);
for ( const auto &it : onnx_graph.initializer() ) {
if (it.name() == "input_name") {
params.push_back(it);
break;
}
const auto &input_name = onnx_node.input(1);
for (const auto &it : onnx_graph.initializer()) {
if (it.name() == input_name) {
params.push_back(it);
break;
}
}

const onnx::TensorProto *slope = &params[0];
if (slope == nullptr) {
// MS_LOGE("input error")
MS_LOG(ERROR) << "input error";
return RET_PARAM_INVALID;
}
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) {
// MS_LOGE("memcpy_s failed")
return RET_ERROR;
if (slope_size == 1) {
attr->slope.push_back(*slope_raw_data);
attr->channelShared = true;
} else { // TODO(wangzhe) we don't check input tensor's channel size, this may cause problem
attr->slope.resize(slope_size);
attr->channelShared = false;
if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
}

if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Prelu;
op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
op->primitive->value.value = attr.release();
}
return RET_OK;
@@ -75,7 +84,6 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph,

OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser());
OnnxNodeRegistrar g_onnxLeakyReluParser("LeakyRelu", new OnnxLeakeyReluParser());
OnnxNodeRegistrar g_onnxPReluParser("Prelu", new OnnxPReluParser());
OnnxNodeRegistrar g_onnxPReluParser("PRelu", new OnnxPReluParser());
} // namespace lite
} // namespace mindspore


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_RELU_PARSER_H
#define MS_ONNX_RELU_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 15
- 16
mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc View File

@@ -16,17 +16,17 @@

#include <vector>
#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h"
#include "tools/converter/parser/onnx/onnx_reshape_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT());
attr->format = schema::Format_NHWC;
MS_LOG(DEBUG) << "onnx ReshapeParser";
std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT());
attr->format = schema::Format_NCHW;
std::vector<onnx::TensorProto> params;
// TODO(wangzhe) shape may also come from other op, there need refactor to introduce tensor_cache
for (int i = 0; i < onnx_node.input_size(); ++i) {
const auto &input_name = onnx_node.input(i);
for (const auto &it : onnx_graph.initializer()) {
@@ -37,16 +37,16 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph,
}
}
if (params.empty()) {
return RET_OK;
}
if (params.size() != 1) {
// MS_LOGE("input num is ,not equal 1", params.size())
return RET_PARAM_INVALID;
}
MS_LOG(DEBUG) << "shape from another op other than const initializer";
} else {
if (params.size() != 1) {
MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1";
return RET_PARAM_INVALID;
}

auto pre_shape = params[0];
for (int i = 0; i < pre_shape.dims_size(); ++i) {
attr->shape.emplace_back(params[0].dims(i));
for (int i = 0; i < params[0].int64_data_size(); ++i) {
attr->shape.emplace_back(params[0].int64_data(i));
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
@@ -59,4 +59,3 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph,
OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser());
} // namespace lite
} // namespace mindspore


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_RESHAPE_PARSER_H
#define MS_ONNX_RESHAPE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 4
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc View File

@@ -15,17 +15,19 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h"
#include "tools/converter/parser/onnx/onnx_shape_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ShapeParser";
if (op != nullptr) {
std::unique_ptr<schema::ShapeT> attr(new schema::ShapeT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Shape;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_SHAPE_PARSER_H
#define MS_ONNX_SHAPE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h"
#include "tools/converter/parser/onnx/onnx_sigmoid_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
MS_LOG(DEBUG) << "onnx SigmoidParser";
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
attr->type = schema::ActivationType_SIGMOID;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_SIGMOID_PARSER_H
#define MS_ONNX_SIGMOID_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h"
#include "tools/converter/parser/onnx/onnx_slice_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::SliceT> attr(new schema::SliceT());
MS_LOG(DEBUG) << "onnx SliceParser";
std::unique_ptr<schema::SliceT> attr(new schema::SliceT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "starts") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_SLICE_PARSER_H
#define MS_ONNX_SLICE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h"
#include "tools/converter/parser/onnx/onnx_softmax_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::SoftMaxT> attr(new schema::SoftMaxT());
MS_LOG(DEBUG) << "onnx SoftMaxParser";
std::unique_ptr<schema::SoftMaxT> attr(new schema::SoftMaxT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_SOFTMAX_PARSER_H
#define MS_ONNX_SOFTMAX_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 5
- 6
mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc View File

@@ -15,14 +15,14 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h"
#include "tools/converter/parser/onnx/onnx_space_to_depth_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxSPaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::SpaceToDepthT> attr(new schema::SpaceToDepthT());
MS_LOG(DEBUG) << "onnx SpaceToDepthParser";
std::unique_ptr<schema::SpaceToDepthT> attr(new schema::SpaceToDepthT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "blocksize") {
@@ -37,7 +37,6 @@ STATUS OnnxSPaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph,
return RET_OK;
}

OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSPaceToDepthParser());
OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser());
} // namespace lite
} // namespace mindspore


+ 4
- 4
mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h View File

@@ -17,14 +17,14 @@
#ifndef MS_ONNX_SPACE_TO_DEPTH_PARSER_H
#define MS_ONNX_SPACE_TO_DEPTH_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {
class OnnxSPaceToDepthParser : public OnnxNodeParser {
class OnnxSpaceToDepthParser : public OnnxNodeParser {
public:
OnnxSPaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {}
OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h"
#include "tools/converter/parser/onnx/onnx_squeeze_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::SqueezeT> attr(new schema::SqueezeT());
MS_LOG(DEBUG) << "onnx SqueezeParser";
std::unique_ptr<schema::SqueezeT> attr(new schema::SqueezeT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_SQUEEZE_PARSER_H
#define MS_ONNX_SQUEEZE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 4
- 3
mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc View File

@@ -15,15 +15,17 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h"
#include "tools/converter/parser/onnx/onnx_tile_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TileParser";
if (op != nullptr) {
std::unique_ptr<schema::TileT> attr(new schema::TileT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Tile;
op->primitive->value.value = nullptr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
@@ -31,4 +33,3 @@ STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser());
} // namespace lite
} // namespace mindspore


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_TILE_PARSER_H
#define MS_ONNX_TILE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h"
#include "tools/converter/parser/onnx/onnx_transpose_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::TransposeT> attr(new schema::TransposeT());
MS_LOG(DEBUG) << "onnx TransposeParser";
std::unique_ptr<schema::TransposeT> attr(new schema::TransposeT());
attr->conjugate = false;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_TRANSPOSE_PARSER_H
#define MS_ONNX_TRANSPOSE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h"
#include "tools/converter/parser/onnx/onnx_unsample_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::UpsampleT> attr(new schema::UpsampleT());
MS_LOG(DEBUG) << "onnx UpsampleParser";
std::unique_ptr<schema::UpsampleT> attr(new schema::UpsampleT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "mode") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_UPSAMPLE_PARSER_H
#define MS_ONNX_UPSAMPLE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 3
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc View File

@@ -15,14 +15,15 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h"
#include "tools/converter/parser/onnx/onnx_unsqueeze_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
unique_ptr<schema::UnsqueezeT> attr(new schema::UnsqueezeT());
MS_LOG(DEBUG) << "onnx UnSqueezeParser";
std::unique_ptr<schema::UnsqueezeT> attr(new schema::UnsqueezeT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_UNSQUEEZE_PARSER_H
#define MS_ONNX_UNSQUEEZE_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


+ 2
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc View File

@@ -15,13 +15,14 @@
*/

#include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h"
#include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h"

namespace mindspore {
namespace lite {
STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx UnusefulNodeParser";
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
if (onnx_node.op_type() == "Int8Quantize") {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h View File

@@ -17,8 +17,8 @@
#ifndef MS_ONNX_UNUSEFUL_PARSER_H
#define MS_ONNX_UNUSEFUL_PARSER_H

#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {


Loading…
Cancel
Save