Browse Source

!11743 fix parser static check

From: @lyvette
Reviewed-by: 
Signed-off-by:
tags/v1.1.1
mindspore-ci-bot Gitee 5 years ago
parent
commit
31e2fc7d10
100 changed files with 838 additions and 1465 deletions
  1. +169
    -169
      mindspore/lite/src/ops/ops_utils.cc
  2. +14
    -30
      mindspore/lite/tools/converter/parser/caffe/caffe_activation_parser.cc
  3. +8
    -12
      mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.cc
  4. +6
    -10
      mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc
  5. +3
    -7
      mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.cc
  6. +16
    -23
      mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc
  7. +7
    -11
      mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.cc
  8. +16
    -25
      mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc
  9. +6
    -10
      mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc
  10. +3
    -7
      mindspore/lite/tools/converter/parser/caffe/caffe_elu_parser.cc
  11. +9
    -12
      mindspore/lite/tools/converter/parser/caffe/caffe_exp_parser.cc
  12. +2
    -6
      mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc
  13. +7
    -11
      mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc
  14. +7
    -10
      mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc
  15. +8
    -8
      mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
  16. +1
    -1
      mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h
  17. +3
    -7
      mindspore/lite/tools/converter/parser/caffe/caffe_permute_parser.cc
  18. +16
    -24
      mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.cc
  19. +5
    -9
      mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.cc
  20. +6
    -10
      mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.cc
  21. +9
    -13
      mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc
  22. +3
    -7
      mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.cc
  23. +3
    -21
      mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc
  24. +0
    -2
      mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h
  25. +7
    -11
      mindspore/lite/tools/converter/parser/caffe/caffe_slice_parser.cc
  26. +4
    -8
      mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.cc
  27. +5
    -8
      mindspore/lite/tools/converter/parser/caffe/caffe_tile_parser.cc
  28. +22
    -46
      mindspore/lite/tools/converter/parser/onnx/onnx_activation_parser.cc
  29. +2
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.cc
  30. +4
    -8
      mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc
  31. +62
    -199
      mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc
  32. +4
    -8
      mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc
  33. +2
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc
  34. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc
  35. +6
    -10
      mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc
  36. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc
  37. +4
    -8
      mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc
  38. +5
    -11
      mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc
  39. +1
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h
  40. +30
    -34
      mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc
  41. +19
    -20
      mindspore/lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc
  42. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc
  43. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc
  44. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc
  45. +2
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc
  46. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc
  47. +4
    -8
      mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.cc
  48. +11
    -22
      mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc
  49. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h
  50. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.cc
  51. +4
    -8
      mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc
  52. +4
    -8
      mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc
  53. +6
    -10
      mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc
  54. +8
    -12
      mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.cc
  55. +4
    -8
      mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc
  56. +30
    -45
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
  57. +5
    -5
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
  58. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc
  59. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc
  60. +9
    -13
      mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc
  61. +21
    -29
      mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc
  62. +6
    -10
      mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc
  63. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.cc
  64. +11
    -15
      mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc
  65. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc
  66. +11
    -14
      mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.cc
  67. +2
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc
  68. +8
    -11
      mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc
  69. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc
  70. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc
  71. +7
    -10
      mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc
  72. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc
  73. +2
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc
  74. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.cc
  75. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc
  76. +3
    -7
      mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc
  77. +7
    -11
      mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc
  78. +7
    -10
      mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc
  79. +24
    -72
      mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc
  80. +3
    -7
      mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc
  81. +2
    -6
      mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc
  82. +3
    -7
      mindspore/lite/tools/converter/parser/tf/tf_cast_parser.cc
  83. +3
    -7
      mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc
  84. +12
    -16
      mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc
  85. +2
    -6
      mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc
  86. +3
    -7
      mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc
  87. +2
    -6
      mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc
  88. +4
    -8
      mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.cc
  89. +3
    -7
      mindspore/lite/tools/converter/parser/tf/tf_pack_parser.cc
  90. +5
    -9
      mindspore/lite/tools/converter/parser/tf/tf_ragged_range_parser.cc
  91. +6
    -9
      mindspore/lite/tools/converter/parser/tf/tf_range_parser.cc
  92. +9
    -13
      mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc
  93. +2
    -6
      mindspore/lite/tools/converter/parser/tf/tf_reshape_parser.cc
  94. +4
    -8
      mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc
  95. +2
    -6
      mindspore/lite/tools/converter/parser/tf/tf_round_parser.cc
  96. +2
    -6
      mindspore/lite/tools/converter/parser/tf/tf_shape_parser.cc
  97. +7
    -11
      mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc
  98. +3
    -7
      mindspore/lite/tools/converter/parser/tf/tf_squeeze_parser.cc
  99. +7
    -11
      mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc
  100. +4
    -8
      mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.cc

+ 169
- 169
mindspore/lite/src/ops/ops_utils.cc View File

@@ -651,184 +651,184 @@ schema::PrimitiveT *ZerosLikePrimitiveCreator(const AnfNodePtr &node) {
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}

RegistryMSOps g_AbsPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
RegistryMSOps g_ActivationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
RegistryMSOps g_ActivationGradPrimitiveCreatorRegistry("ActivationGrad", ActivationGradPrimitiveCreator);
RegistryMSOps g_AddPrimitiveCreatorRegistry("Add", AddFusionPrimitiveCreator);
RegistryMSOps g_AddFusionPrimitiveCreatorRegistry("AddFusion", AddFusionPrimitiveCreator);
RegistryMSOps g_AddGradPrimitiveCreatorRegistry("AddGrad", AddGradPrimitiveCreator);
RegistryMSOps g_AdderPrimitiveCreatorRegistry("Adder", AdderFusionPrimitiveCreator);
RegistryMSOps g_AdderFusionPrimitiveCreatorRegistry("AdderFusion", AdderFusionPrimitiveCreator);
RegistryMSOps g_AddNPrimitiveCreatorRegistry("AddN", AddNPrimitiveCreator);
RegistryMSOps g_AllPrimitiveCreatorRegistry("All", AllPrimitiveCreator);
RegistryMSOps g_ApplyMomentumPrimitiveCreatorRegistry("ApplyMomentum", ApplyMomentumPrimitiveCreator);
RegistryMSOps g_ArgMaxPrimitiveCreatorRegistry("ArgMax", ArgMaxFusionPrimitiveCreator);
RegistryMSOps g_ArgMaxFusionPrimitiveCreatorRegistry("ArgMaxFusion", ArgMaxFusionPrimitiveCreator);
RegistryMSOps g_ArgMinPrimitiveCreatorRegistry("ArgMin", ArgMinFusionPrimitiveCreator);
RegistryMSOps g_ArgMinFusionPrimitiveCreatorRegistry("ArgMinFusion", ArgMinFusionPrimitiveCreator);
RegistryMSOps g_AssertPrimitiveCreatorRegistry("Assert", AssertPrimitiveCreator);
RegistryMSOps g_AssignPrimitiveCreatorRegistry("Assign", AssignPrimitiveCreator);
RegistryMSOps g_AssignAddPrimitiveCreatorRegistry("AssignAdd", AssignAddPrimitiveCreator);
RegistryMSOps g_AvgPoolPrimitiveCreatorRegistry("AvgPool", AvgPoolFusionPrimitiveCreator);
RegistryMSOps g_AvgPoolFusionPrimitiveCreatorRegistry("AvgPoolFusion", AvgPoolFusionPrimitiveCreator);
RegistryMSOps g_BatchNormPrimitiveCreatorRegistry("BatchNorm", BatchNormPrimitiveCreator);
RegistryMSOps g_BatchToSpacePrimitiveCreatorRegistry("BatchToSpace", BatchToSpacePrimitiveCreator);
RegistryMSOps g_BatchToSpaceNDPrimitiveCreatorRegistry("BatchToSpaceND", BatchToSpaceNDPrimitiveCreator);
RegistryMSOps g_BiasAddPrimitiveCreatorRegistry("BiasAdd", BiasAddPrimitiveCreator);
RegistryMSOps g_BNGradPrimitiveCreatorRegistry("BNGrad", BNGradPrimitiveCreator);
RegistryMSOps g_BroadcastToPrimitiveCreatorRegistry("BroadcastTo", BroadcastToPrimitiveCreator);
RegistryMSOps g_CastPrimitiveCreatorRegistry("Cast", CastPrimitiveCreator);
RegistryMSOps g_CeilPrimitiveCreatorRegistry("Ceil", CeilPrimitiveCreator);
RegistryMSOps g_ClipPrimitiveCreatorRegistry("Clip", ClipPrimitiveCreator);
RegistryMSOps g_ConcatPrimitiveCreatorRegistry("Concat", ConcatPrimitiveCreator);
// RegistryMSOps g_ControlDependPrimitiveCreatorRegistry("ControlDepend", ControlDependPrimitiveCreator);
RegistryMSOps g_Conv2DBackpropFilterFusionPrimitiveCreatorRegistry("Conv2DBackpropFilterFusion",
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
RegistryMSOps g_activationGradPrimitiveCreatorRegistry("ActivationGrad", ActivationGradPrimitiveCreator);
RegistryMSOps g_addPrimitiveCreatorRegistry("Add", AddFusionPrimitiveCreator);
RegistryMSOps g_addFusionPrimitiveCreatorRegistry("AddFusion", AddFusionPrimitiveCreator);
RegistryMSOps g_addGradPrimitiveCreatorRegistry("AddGrad", AddGradPrimitiveCreator);
RegistryMSOps g_adderPrimitiveCreatorRegistry("Adder", AdderFusionPrimitiveCreator);
RegistryMSOps g_adderFusionPrimitiveCreatorRegistry("AdderFusion", AdderFusionPrimitiveCreator);
RegistryMSOps g_addNPrimitiveCreatorRegistry("AddN", AddNPrimitiveCreator);
RegistryMSOps g_allPrimitiveCreatorRegistry("All", AllPrimitiveCreator);
RegistryMSOps g_applyMomentumPrimitiveCreatorRegistry("ApplyMomentum", ApplyMomentumPrimitiveCreator);
RegistryMSOps g_argMaxPrimitiveCreatorRegistry("ArgMax", ArgMaxFusionPrimitiveCreator);
RegistryMSOps g_argMaxFusionPrimitiveCreatorRegistry("ArgMaxFusion", ArgMaxFusionPrimitiveCreator);
RegistryMSOps g_argMinPrimitiveCreatorRegistry("ArgMin", ArgMinFusionPrimitiveCreator);
RegistryMSOps g_argMinFusionPrimitiveCreatorRegistry("ArgMinFusion", ArgMinFusionPrimitiveCreator);
RegistryMSOps g_assertPrimitiveCreatorRegistry("Assert", AssertPrimitiveCreator);
RegistryMSOps g_assignPrimitiveCreatorRegistry("Assign", AssignPrimitiveCreator);
RegistryMSOps g_assignAddPrimitiveCreatorRegistry("AssignAdd", AssignAddPrimitiveCreator);
RegistryMSOps g_avgPoolPrimitiveCreatorRegistry("AvgPool", AvgPoolFusionPrimitiveCreator);
RegistryMSOps g_avgPoolFusionPrimitiveCreatorRegistry("AvgPoolFusion", AvgPoolFusionPrimitiveCreator);
RegistryMSOps g_batchNormPrimitiveCreatorRegistry("BatchNorm", BatchNormPrimitiveCreator);
RegistryMSOps g_batchToSpacePrimitiveCreatorRegistry("BatchToSpace", BatchToSpacePrimitiveCreator);
RegistryMSOps g_batchToSpaceNDPrimitiveCreatorRegistry("BatchToSpaceND", BatchToSpaceNDPrimitiveCreator);
RegistryMSOps g_biasAddPrimitiveCreatorRegistry("BiasAdd", BiasAddPrimitiveCreator);
RegistryMSOps g_bNGradPrimitiveCreatorRegistry("BNGrad", BNGradPrimitiveCreator);
RegistryMSOps g_broadcastToPrimitiveCreatorRegistry("BroadcastTo", BroadcastToPrimitiveCreator);
RegistryMSOps g_castPrimitiveCreatorRegistry("Cast", CastPrimitiveCreator);
RegistryMSOps g_ceilPrimitiveCreatorRegistry("Ceil", CeilPrimitiveCreator);
RegistryMSOps g_clipPrimitiveCreatorRegistry("Clip", ClipPrimitiveCreator);
RegistryMSOps g_concatPrimitiveCreatorRegistry("Concat", ConcatPrimitiveCreator);
// RegistryMSOps g_controlDependPrimitiveCreatorRegistry("ControlDepend", ControlDependPrimitiveCreator);
RegistryMSOps g_conv2DBackpropFilterFusionPrimitiveCreatorRegistry("Conv2DBackpropFilterFusion",
Conv2DBackpropFilterFusionPrimitiveCreator);
RegistryMSOps g_Conv2DBackpropInputFusionPrimitiveCreatorRegistry("Conv2DBackpropInputFusion",
RegistryMSOps g_conv2DBackpropInputFusionPrimitiveCreatorRegistry("Conv2DBackpropInputFusion",
Conv2DBackpropInputFusionPrimitiveCreator);
RegistryMSOps g_Conv2DPrimitiveCreatorRegistry("Conv2D", Conv2DFusionPrimitiveCreator);
RegistryMSOps g_Conv2DFusionPrimitiveCreatorRegistry("Conv2DFusion", Conv2DFusionPrimitiveCreator);
RegistryMSOps g_Conv2dTransposePrimitiveCreatorRegistry("Conv2dTranspose", Conv2dTransposeFusionPrimitiveCreator);
RegistryMSOps g_Conv2dTransposeFusionPrimitiveCreatorRegistry("Conv2dTransposeFusion",
RegistryMSOps g_conv2DPrimitiveCreatorRegistry("Conv2D", Conv2DFusionPrimitiveCreator);
RegistryMSOps g_conv2DFusionPrimitiveCreatorRegistry("Conv2DFusion", Conv2DFusionPrimitiveCreator);
RegistryMSOps g_conv2dTransposePrimitiveCreatorRegistry("Conv2dTranspose", Conv2dTransposeFusionPrimitiveCreator);
RegistryMSOps g_conv2dTransposeFusionPrimitiveCreatorRegistry("Conv2dTransposeFusion",
Conv2dTransposeFusionPrimitiveCreator);
RegistryMSOps g_ConstantOfShapePrimitiveCreatorRegistry("ConstantOfShape", ConstantOfShapePrimitiveCreator);
RegistryMSOps g_CosPrimitiveCreatorRegistry("Cos", CosPrimitiveCreator);
RegistryMSOps g_CropPrimitiveCreatorRegistry("Crop", CropPrimitiveCreator);
RegistryMSOps g_CustomExtractFeaturesPrimitiveCreatorRegistry("CustomExtractFeatures",
RegistryMSOps g_constantOfShapePrimitiveCreatorRegistry("ConstantOfShape", ConstantOfShapePrimitiveCreator);
RegistryMSOps g_cosPrimitiveCreatorRegistry("Cos", CosPrimitiveCreator);
RegistryMSOps g_cropPrimitiveCreatorRegistry("Crop", CropPrimitiveCreator);
RegistryMSOps g_customExtractFeaturesPrimitiveCreatorRegistry("CustomExtractFeatures",
CustomExtractFeaturesPrimitiveCreator);
RegistryMSOps g_CustomNormalizePrimitiveCreatorRegistry("CustomNormalize", CustomNormalizePrimitiveCreator);
RegistryMSOps g_CustomPredictPrimitiveCreatorRegistry("CustomPredict", CustomPredictPrimitiveCreator);
RegistryMSOps g_DependPrimitiveCreatorRegistry("Depend", DependPrimitiveCreator);
RegistryMSOps g_DepthToSpacePrimitiveCreatorRegistry("DepthToSpace", DepthToSpacePrimitiveCreator);
RegistryMSOps g_DetectionPostProcessPrimitiveCreatorRegistry("DetectionPostProcess",
RegistryMSOps g_customNormalizePrimitiveCreatorRegistry("CustomNormalize", CustomNormalizePrimitiveCreator);
RegistryMSOps g_customPredictPrimitiveCreatorRegistry("CustomPredict", CustomPredictPrimitiveCreator);
RegistryMSOps g_dependPrimitiveCreatorRegistry("Depend", DependPrimitiveCreator);
RegistryMSOps g_depthToSpacePrimitiveCreatorRegistry("DepthToSpace", DepthToSpacePrimitiveCreator);
RegistryMSOps g_detectionPostProcessPrimitiveCreatorRegistry("DetectionPostProcess",
DetectionPostProcessPrimitiveCreator);
RegistryMSOps g_DivPrimitiveCreatorRegistry("Div", DivFusionPrimitiveCreator);
RegistryMSOps g_DivFusionPrimitiveCreatorRegistry("DivFusion", DivFusionPrimitiveCreator);
RegistryMSOps g_DivGradPrimitiveCreatorRegistry("DivGrad", DivGradPrimitiveCreator);
RegistryMSOps g_DropoutPrimitiveCreatorRegistry("Dropout", DropoutPrimitiveCreator);
RegistryMSOps g_DropoutGradPrimitiveCreatorRegistry("DropoutGrad", DropoutGradPrimitiveCreator);
RegistryMSOps g_EltwisePrimitiveCreatorRegistry("Eltwise", EltwisePrimitiveCreator);
RegistryMSOps g_EluPrimitiveCreatorRegistry("Elu", EluPrimitiveCreator);
RegistryMSOps g_EqualPrimitiveCreatorRegistry("Equal", EqualPrimitiveCreator);
RegistryMSOps g_EmbeddingLookupFusionPrimitiveCreatorRegistry("EmbeddingLookupFusion",
RegistryMSOps g_divPrimitiveCreatorRegistry("Div", DivFusionPrimitiveCreator);
RegistryMSOps g_divFusionPrimitiveCreatorRegistry("DivFusion", DivFusionPrimitiveCreator);
RegistryMSOps g_divGradPrimitiveCreatorRegistry("DivGrad", DivGradPrimitiveCreator);
RegistryMSOps g_dropoutPrimitiveCreatorRegistry("Dropout", DropoutPrimitiveCreator);
RegistryMSOps g_dropoutGradPrimitiveCreatorRegistry("DropoutGrad", DropoutGradPrimitiveCreator);
RegistryMSOps g_eltwisePrimitiveCreatorRegistry("Eltwise", EltwisePrimitiveCreator);
RegistryMSOps g_eluPrimitiveCreatorRegistry("Elu", EluPrimitiveCreator);
RegistryMSOps g_equalPrimitiveCreatorRegistry("Equal", EqualPrimitiveCreator);
RegistryMSOps g_embeddingLookupFusionPrimitiveCreatorRegistry("EmbeddingLookupFusion",
EmbeddingLookupFusionPrimitiveCreator);
RegistryMSOps g_ExpandDimsPrimitiveCreatorRegistry("ExpandDims", ExpandDimsPrimitiveCreator);
RegistryMSOps g_ExpPrimitiveCreatorRegistry("Exp", ExpFusionPrimitiveCreator);
RegistryMSOps g_ExpFusionPrimitiveCreatorRegistry("ExpFusion", ExpFusionPrimitiveCreator);
RegistryMSOps g_FftImagPrimitiveCreatorRegistry("FftImag", FftImagPrimitiveCreator);
RegistryMSOps g_FftRealPrimitiveCreatorRegistry("FftReal", FftRealPrimitiveCreator);
RegistryMSOps g_FillPrimitiveCreatorRegistry("Fill", FillPrimitiveCreator);
RegistryMSOps g_FlattenPrimitiveCreatorRegistry("Flatten", FlattenPrimitiveCreator);
RegistryMSOps g_FlattenGradPrimitiveCreatorRegistry("FlattenGrad", FlattenGradPrimitiveCreator);
RegistryMSOps g_FloorPrimitiveCreatorRegistry("Floor", FloorPrimitiveCreator);
RegistryMSOps g_FloorDivPrimitiveCreatorRegistry("FloorDiv", FloorDivPrimitiveCreator);
RegistryMSOps g_FloorModPrimitiveCreatorRegistry("FloorMod", FloorModPrimitiveCreator);
RegistryMSOps g_FullConnectionPrimitiveCreatorRegistry("FullConnection", FullConnectionPrimitiveCreator);
RegistryMSOps g_FusedBatchNormPrimitiveCreatorRegistry("FusedBatchNorm", FusedBatchNormPrimitiveCreator);
RegistryMSOps g_GatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator);
RegistryMSOps g_GatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator);
RegistryMSOps g_GreaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator);
RegistryMSOps g_GreaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator);
RegistryMSOps g_HashtableLookupPrimitiveCreatorRegistry("HashtableLookup", HashtableLookupPrimitiveCreator);
RegistryMSOps g_InstanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNormPrimitiveCreator);
RegistryMSOps g_LayerNormPrimitiveCreatorRegistry("LayerNorm", LayerNormFusionPrimitiveCreator);
RegistryMSOps g_LayerNormFusionPrimitiveCreatorRegistry("LayerNormFusion", LayerNormFusionPrimitiveCreator);
RegistryMSOps g_LeakyReluPrimitiveCreatorRegistry("LeakyRelu", LeakyReluPrimitiveCreator);
RegistryMSOps g_LessPrimitiveCreatorRegistry("Less", LessPrimitiveCreator);
RegistryMSOps g_LessEqualPrimitiveCreatorRegistry("LessEqual", LessEqualPrimitiveCreator);
RegistryMSOps g_LogPrimitiveCreatorRegistry("Log", LogPrimitiveCreator);
RegistryMSOps g_LogGradPrimitiveCreatorRegistry("LogGrad", LogGradPrimitiveCreator);
RegistryMSOps g_LogicalAndPrimitiveCreatorRegistry("LogicalAnd", LogicalAndPrimitiveCreator);
RegistryMSOps g_LogicalNotPrimitiveCreatorRegistry("LogicalNot", LogicalNotPrimitiveCreator);
RegistryMSOps g_LogicalOrPrimitiveCreatorRegistry("LogicalOr", LogicalOrPrimitiveCreator);
RegistryMSOps g_LpNormalizationPrimitiveCreatorRegistry("LpNormalization", LpNormalizationPrimitiveCreator);
RegistryMSOps g_LrnPrimitiveCreatorRegistry("Lrn", LrnPrimitiveCreator);
RegistryMSOps g_LshProjectionPrimitiveCreatorRegistry("LshProjection", LshProjectionPrimitiveCreator);
RegistryMSOps g_LSTMPrimitiveCreatorRegistry("LSTM", LSTMPrimitiveCreator);
RegistryMSOps g_L2NormalizeFusionPrimitiveCreatorRegistry("L2NormalizeFusion", L2NormalizeFusionPrimitiveCreator);
RegistryMSOps g_MatMulPrimitiveCreatorRegistry("MatMul", MatMulPrimitiveCreator);
RegistryMSOps g_MaximumPrimitiveCreatorRegistry("Maximum", MaximumPrimitiveCreator);
RegistryMSOps g_MaximumGradPrimitiveCreatorRegistry("MaximumGrad", MaximumGradPrimitiveCreator);
RegistryMSOps g_MaxPoolPrimitiveCreatorRegistry("MaxPool", MaxPoolFusionPrimitiveCreator);
RegistryMSOps g_MaxPoolFusionPrimitiveCreatorRegistry("MaxPoolFusion", MaxPoolFusionPrimitiveCreator);
RegistryMSOps g_MergePrimitiveCreatorRegistry("Merge", MergePrimitiveCreator);
RegistryMSOps g_MinimumPrimitiveCreatorRegistry("Minimum", MinimumPrimitiveCreator);
RegistryMSOps g_MinimumGradPrimitiveCreatorRegistry("MinimumGrad", MinimumGradPrimitiveCreator);
RegistryMSOps g_ModPrimitiveCreatorRegistry("Mod", ModPrimitiveCreator);
RegistryMSOps g_MulPrimitiveCreatorRegistry("Mul", MulFusionPrimitiveCreator);
RegistryMSOps g_MulMulFusionPrimitiveCreatorRegistry("MulFusion", MulFusionPrimitiveCreator);
RegistryMSOps g_MulGradPrimitiveCreatorRegistry("MulGrad", MulGradPrimitiveCreator);
RegistryMSOps g_NegPrimitiveCreatorRegistry("Neg", NegPrimitiveCreator);
RegistryMSOps g_NegGradPrimitiveCreatorRegistry("NegGrad", NegGradPrimitiveCreator);
RegistryMSOps g_NonMaxSuppressionPrimitiveCreatorRegistry("NonMaxSuppression", NonMaxSuppressionPrimitiveCreator);
RegistryMSOps g_NotEqualPrimitiveCreatorRegistry("NotEqual", NotEqualPrimitiveCreator);
RegistryMSOps g_OneHotPrimitiveCreatorRegistry("OneHot", OneHotPrimitiveCreator);
RegistryMSOps g_OnesLikePrimitiveCreatorRegistry("OnesLike", OnesLikePrimitiveCreator);
RegistryMSOps g_PadPrimitiveCreatorRegistry("Pad", PadFusionPrimitiveCreator);
RegistryMSOps g_PadFusionPrimitiveCreatorRegistry("PadFusion", PadFusionPrimitiveCreator);
RegistryMSOps g_PartialFusionPrimitiveCreatorRegistry("PartialFusion", PartialFusionPrimitiveCreator);
RegistryMSOps g_PowerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator);
RegistryMSOps g_PowFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator);
RegistryMSOps g_PReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator);
RegistryMSOps g_RangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator);
RegistryMSOps g_RankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator);
RegistryMSOps g_ReciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator);
RegistryMSOps g_RealDivPrimitiveCreatorRegistry("RealDiv", RealDivPrimitiveCreator);
RegistryMSOps g_ReducePrimitiveCreatorRegistry("Reduce", ReduceFusionPrimitiveCreator);
RegistryMSOps g_ReduceFusionPrimitiveCreatorRegistry("ReduceFusion", ReduceFusionPrimitiveCreator);
RegistryMSOps g_ReshapePrimitiveCreatorRegistry("Reshape", ReshapePrimitiveCreator);
RegistryMSOps g_ResizePrimitiveCreatorRegistry("Resize", ResizePrimitiveCreator);
RegistryMSOps g_ReverseV2PrimitiveCreatorRegistry("ReverseV2", ReverseV2PrimitiveCreator);
RegistryMSOps g_ReverseSequencePrimitiveCreatorRegistry("ReverseSequence", ReverseSequencePrimitiveCreator);
RegistryMSOps g_RfftPrimitiveCreatorRegistry("Rfft", RfftPrimitiveCreator);
RegistryMSOps g_ROIPoolingPrimitiveCreatorRegistry("ROIPooling", ROIPoolingPrimitiveCreator);
RegistryMSOps g_RoundPrimitiveCreatorRegistry("Round", RoundPrimitiveCreator);
RegistryMSOps g_RsqrtPrimitiveCreatorRegistry("Rsqrt", RsqrtPrimitiveCreator);
RegistryMSOps g_QuantDTypeCastPrimitiveCreatorRegistry("QuantDTypeCast", QuantDTypeCastPrimitiveCreator);
RegistryMSOps g_ScalePrimitiveCreatorRegistry("Scale", ScaleFusionPrimitiveCreator);
RegistryMSOps g_ScaleFusionPrimitiveCreatorRegistry("ScaleFusion", ScaleFusionPrimitiveCreator);
RegistryMSOps g_ShapePrimitiveCreatorRegistry("Shape", ShapePrimitiveCreator);
RegistryMSOps g_SigmoidCrossEntropyWithLogitsPrimitiveCreatorRegistry("SigmoidCrossEntropyWithLogits",
RegistryMSOps g_expandDimsPrimitiveCreatorRegistry("ExpandDims", ExpandDimsPrimitiveCreator);
RegistryMSOps g_expPrimitiveCreatorRegistry("Exp", ExpFusionPrimitiveCreator);
RegistryMSOps g_expFusionPrimitiveCreatorRegistry("ExpFusion", ExpFusionPrimitiveCreator);
RegistryMSOps g_fftImagPrimitiveCreatorRegistry("FftImag", FftImagPrimitiveCreator);
RegistryMSOps g_fftRealPrimitiveCreatorRegistry("FftReal", FftRealPrimitiveCreator);
RegistryMSOps g_fillPrimitiveCreatorRegistry("Fill", FillPrimitiveCreator);
RegistryMSOps g_flattenPrimitiveCreatorRegistry("Flatten", FlattenPrimitiveCreator);
RegistryMSOps g_flattenGradPrimitiveCreatorRegistry("FlattenGrad", FlattenGradPrimitiveCreator);
RegistryMSOps g_floorPrimitiveCreatorRegistry("Floor", FloorPrimitiveCreator);
RegistryMSOps g_floorDivPrimitiveCreatorRegistry("FloorDiv", FloorDivPrimitiveCreator);
RegistryMSOps g_floorModPrimitiveCreatorRegistry("FloorMod", FloorModPrimitiveCreator);
RegistryMSOps g_fullConnectionPrimitiveCreatorRegistry("FullConnection", FullConnectionPrimitiveCreator);
RegistryMSOps g_fusedBatchNormPrimitiveCreatorRegistry("FusedBatchNorm", FusedBatchNormPrimitiveCreator);
RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator);
RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator);
RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator);
RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator);
RegistryMSOps g_hashtableLookupPrimitiveCreatorRegistry("HashtableLookup", HashtableLookupPrimitiveCreator);
RegistryMSOps g_instanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNormPrimitiveCreator);
RegistryMSOps g_layerNormPrimitiveCreatorRegistry("LayerNorm", LayerNormFusionPrimitiveCreator);
RegistryMSOps g_layerNormFusionPrimitiveCreatorRegistry("LayerNormFusion", LayerNormFusionPrimitiveCreator);
RegistryMSOps g_leakyReluPrimitiveCreatorRegistry("LeakyRelu", LeakyReluPrimitiveCreator);
RegistryMSOps g_lessPrimitiveCreatorRegistry("Less", LessPrimitiveCreator);
RegistryMSOps g_lessEqualPrimitiveCreatorRegistry("LessEqual", LessEqualPrimitiveCreator);
RegistryMSOps g_logPrimitiveCreatorRegistry("Log", LogPrimitiveCreator);
RegistryMSOps g_logGradPrimitiveCreatorRegistry("LogGrad", LogGradPrimitiveCreator);
RegistryMSOps g_logicalAndPrimitiveCreatorRegistry("LogicalAnd", LogicalAndPrimitiveCreator);
RegistryMSOps g_logicalNotPrimitiveCreatorRegistry("LogicalNot", LogicalNotPrimitiveCreator);
RegistryMSOps g_logicalOrPrimitiveCreatorRegistry("LogicalOr", LogicalOrPrimitiveCreator);
RegistryMSOps g_lpNormalizationPrimitiveCreatorRegistry("LpNormalization", LpNormalizationPrimitiveCreator);
RegistryMSOps g_lrnPrimitiveCreatorRegistry("Lrn", LrnPrimitiveCreator);
RegistryMSOps g_lshProjectionPrimitiveCreatorRegistry("LshProjection", LshProjectionPrimitiveCreator);
RegistryMSOps g_lSTMPrimitiveCreatorRegistry("LSTM", LSTMPrimitiveCreator);
RegistryMSOps g_l2NormalizeFusionPrimitiveCreatorRegistry("L2NormalizeFusion", L2NormalizeFusionPrimitiveCreator);
RegistryMSOps g_matMulPrimitiveCreatorRegistry("MatMul", MatMulPrimitiveCreator);
RegistryMSOps g_maximumPrimitiveCreatorRegistry("Maximum", MaximumPrimitiveCreator);
RegistryMSOps g_maximumGradPrimitiveCreatorRegistry("MaximumGrad", MaximumGradPrimitiveCreator);
RegistryMSOps g_maxPoolPrimitiveCreatorRegistry("MaxPool", MaxPoolFusionPrimitiveCreator);
RegistryMSOps g_maxPoolFusionPrimitiveCreatorRegistry("MaxPoolFusion", MaxPoolFusionPrimitiveCreator);
RegistryMSOps g_mergePrimitiveCreatorRegistry("Merge", MergePrimitiveCreator);
RegistryMSOps g_minimumPrimitiveCreatorRegistry("Minimum", MinimumPrimitiveCreator);
RegistryMSOps g_minimumGradPrimitiveCreatorRegistry("MinimumGrad", MinimumGradPrimitiveCreator);
RegistryMSOps g_modPrimitiveCreatorRegistry("Mod", ModPrimitiveCreator);
RegistryMSOps g_mulPrimitiveCreatorRegistry("Mul", MulFusionPrimitiveCreator);
RegistryMSOps g_mulMulFusionPrimitiveCreatorRegistry("MulFusion", MulFusionPrimitiveCreator);
RegistryMSOps g_mulGradPrimitiveCreatorRegistry("MulGrad", MulGradPrimitiveCreator);
RegistryMSOps g_negPrimitiveCreatorRegistry("Neg", NegPrimitiveCreator);
RegistryMSOps g_negGradPrimitiveCreatorRegistry("NegGrad", NegGradPrimitiveCreator);
RegistryMSOps g_nonMaxSuppressionPrimitiveCreatorRegistry("NonMaxSuppression", NonMaxSuppressionPrimitiveCreator);
RegistryMSOps g_notEqualPrimitiveCreatorRegistry("NotEqual", NotEqualPrimitiveCreator);
RegistryMSOps g_oneHotPrimitiveCreatorRegistry("OneHot", OneHotPrimitiveCreator);
RegistryMSOps g_onesLikePrimitiveCreatorRegistry("OnesLike", OnesLikePrimitiveCreator);
RegistryMSOps g_padPrimitiveCreatorRegistry("Pad", PadFusionPrimitiveCreator);
RegistryMSOps g_padFusionPrimitiveCreatorRegistry("PadFusion", PadFusionPrimitiveCreator);
RegistryMSOps g_partialFusionPrimitiveCreatorRegistry("PartialFusion", PartialFusionPrimitiveCreator);
RegistryMSOps g_powerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator);
RegistryMSOps g_powFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator);
RegistryMSOps g_pReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator);
RegistryMSOps g_rangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator);
RegistryMSOps g_rankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator);
RegistryMSOps g_reciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator);
RegistryMSOps g_realDivPrimitiveCreatorRegistry("RealDiv", RealDivPrimitiveCreator);
RegistryMSOps g_reducePrimitiveCreatorRegistry("Reduce", ReduceFusionPrimitiveCreator);
RegistryMSOps g_reduceFusionPrimitiveCreatorRegistry("ReduceFusion", ReduceFusionPrimitiveCreator);
RegistryMSOps g_reshapePrimitiveCreatorRegistry("Reshape", ReshapePrimitiveCreator);
RegistryMSOps g_resizePrimitiveCreatorRegistry("Resize", ResizePrimitiveCreator);
RegistryMSOps g_reverseV2PrimitiveCreatorRegistry("ReverseV2", ReverseV2PrimitiveCreator);
RegistryMSOps g_reverseSequencePrimitiveCreatorRegistry("ReverseSequence", ReverseSequencePrimitiveCreator);
RegistryMSOps g_rfftPrimitiveCreatorRegistry("Rfft", RfftPrimitiveCreator);
RegistryMSOps g_rOIPoolingPrimitiveCreatorRegistry("ROIPooling", ROIPoolingPrimitiveCreator);
RegistryMSOps g_roundPrimitiveCreatorRegistry("Round", RoundPrimitiveCreator);
RegistryMSOps g_rsqrtPrimitiveCreatorRegistry("Rsqrt", RsqrtPrimitiveCreator);
RegistryMSOps g_quantDTypeCastPrimitiveCreatorRegistry("QuantDTypeCast", QuantDTypeCastPrimitiveCreator);
RegistryMSOps g_scalePrimitiveCreatorRegistry("Scale", ScaleFusionPrimitiveCreator);
RegistryMSOps g_scaleFusionPrimitiveCreatorRegistry("ScaleFusion", ScaleFusionPrimitiveCreator);
RegistryMSOps g_shapePrimitiveCreatorRegistry("Shape", ShapePrimitiveCreator);
RegistryMSOps g_sigmoidCrossEntropyWithLogitsPrimitiveCreatorRegistry("SigmoidCrossEntropyWithLogits",
SigmoidCrossEntropyWithLogitsPrimitiveCreator);
RegistryMSOps g_SigmoidCrossEntropyWithLogitsGradPrimitiveCreatorRegistry(
RegistryMSOps g_sigmoidCrossEntropyWithLogitsGradPrimitiveCreatorRegistry(
"SigmoidCrossEntropyWithLogitsGrad", SigmoidCrossEntropyWithLogitsGradPrimitiveCreator);
RegistryMSOps g_SinPrimitiveCreatorRegistry("Sin", SinPrimitiveCreator);
RegistryMSOps g_SkipGramPrimitiveCreatorRegistry("SkipGram", SkipGramPrimitiveCreator);
RegistryMSOps g_SliceFusionPrimitiveCreatorRegistry("SliceFusion", SliceFusionPrimitiveCreator);
RegistryMSOps g_SmoothL1LossPrimitiveCreatorRegistry("SmoothL1Loss", SmoothL1LossPrimitiveCreator);
RegistryMSOps g_SmoothL1LossGradPrimitiveCreatorRegistry("SmoothL1LossGrad", SmoothL1LossGradPrimitiveCreator);
RegistryMSOps g_SoftmaxPrimitiveCreatorRegistry("Softmax", SoftmaxPrimitiveCreator);
RegistryMSOps g_SpaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator);
RegistryMSOps g_SpaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator);
RegistryMSOps g_SpaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator);
RegistryMSOps g_SparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator);
RegistryMSOps g_SplitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator);
RegistryMSOps g_SqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator);
RegistryMSOps g_SqueezePrimitiveCreatorRegistry("Squeeze", SqueezePrimitiveCreator);
RegistryMSOps g_SquarePrimitiveCreatorRegistry("Square", SquarePrimitiveCreator);
RegistryMSOps g_SquaredDifferencePrimitiveCreatorRegistry("SquaredDifference", SquaredDifferencePrimitiveCreator);
RegistryMSOps g_StackPrimitiveCreatorRegistry("Stack", StackPrimitiveCreator);
RegistryMSOps g_StridedSlicePrimitiveCreatorRegistry("StridedSlice", StridedSlicePrimitiveCreator);
RegistryMSOps g_SubPrimitiveCreatorRegistry("Sub", SubFusionPrimitiveCreator);
RegistryMSOps g_SubFusionPrimitiveCreatorRegistry("SubFusion", SubFusionPrimitiveCreator);
RegistryMSOps g_SubGradPrimitiveCreatorRegistry("SubGrad", SubGradPrimitiveCreator);
RegistryMSOps g_SwitchPrimitiveCreatorRegistry("Switch", SwitchPrimitiveCreator);
RegistryMSOps g_TensorListFromTensorPrimitiveCreatorRegistry("TensorListFromTensor",
RegistryMSOps g_sinPrimitiveCreatorRegistry("Sin", SinPrimitiveCreator);
RegistryMSOps g_skipGramPrimitiveCreatorRegistry("SkipGram", SkipGramPrimitiveCreator);
RegistryMSOps g_sliceFusionPrimitiveCreatorRegistry("SliceFusion", SliceFusionPrimitiveCreator);
RegistryMSOps g_smoothL1LossPrimitiveCreatorRegistry("SmoothL1Loss", SmoothL1LossPrimitiveCreator);
RegistryMSOps g_smoothL1LossGradPrimitiveCreatorRegistry("SmoothL1LossGrad", SmoothL1LossGradPrimitiveCreator);
RegistryMSOps g_softmaxPrimitiveCreatorRegistry("Softmax", SoftmaxPrimitiveCreator);
RegistryMSOps g_spaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator);
RegistryMSOps g_spaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator);
RegistryMSOps g_spaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator);
RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator);
RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator);
RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator);
RegistryMSOps g_squeezePrimitiveCreatorRegistry("Squeeze", SqueezePrimitiveCreator);
RegistryMSOps g_squarePrimitiveCreatorRegistry("Square", SquarePrimitiveCreator);
RegistryMSOps g_squaredDifferencePrimitiveCreatorRegistry("SquaredDifference", SquaredDifferencePrimitiveCreator);
RegistryMSOps g_stackPrimitiveCreatorRegistry("Stack", StackPrimitiveCreator);
RegistryMSOps g_stridedSlicePrimitiveCreatorRegistry("StridedSlice", StridedSlicePrimitiveCreator);
RegistryMSOps g_subPrimitiveCreatorRegistry("Sub", SubFusionPrimitiveCreator);
RegistryMSOps g_subFusionPrimitiveCreatorRegistry("SubFusion", SubFusionPrimitiveCreator);
RegistryMSOps g_subGradPrimitiveCreatorRegistry("SubGrad", SubGradPrimitiveCreator);
RegistryMSOps g_switchPrimitiveCreatorRegistry("Switch", SwitchPrimitiveCreator);
RegistryMSOps g_tensorListFromTensorPrimitiveCreatorRegistry("TensorListFromTensor",
TensorListFromTensorPrimitiveCreator);
RegistryMSOps g_TensorListGetItemPrimitiveCreatorRegistry("TensorListGetItem", TensorListGetItemPrimitiveCreator);
RegistryMSOps g_TensorListReservePrimitiveCreatorRegistry("TensorListReserve", TensorListReservePrimitiveCreator);
RegistryMSOps g_TensorListSetItemPrimitiveCreatorRegistry("TensorListSetItem", TensorListSetItemPrimitiveCreator);
RegistryMSOps g_TensorListStackPrimitiveCreatorRegistry("TensorListStack", TensorListStackPrimitiveCreator);
RegistryMSOps g_TileFusionPrimitiveCreatorRegistry("TileFusion", TileFusionPrimitiveCreator);
RegistryMSOps g_TopKPrimitiveCreatorRegistry("TopK", TopKFusionPrimitiveCreator);
RegistryMSOps g_TopKFusionPrimitiveCreatorRegistry("TopKFusion", TopKFusionPrimitiveCreator);
RegistryMSOps g_TransposePrimitiveCreatorxRegistry("Transpose", TransposePrimitiveCreator);
RegistryMSOps g_UniquePrimitiveCreatorRegistry("Unique", UniquePrimitiveCreator);
RegistryMSOps g_UnpackPrimitiveCreatorRegistry("Unpack", UnpackPrimitiveCreator);
RegistryMSOps g_UnsortedSegmentSumPrimitiveCreatorRegistry("UnsortedSegmentSum", UnsortedSegmentSumPrimitiveCreator);
RegistryMSOps g_UnsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiveCreator);
RegistryMSOps g_WherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator);
RegistryMSOps g_ZerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator);
RegistryMSOps g_tensorListGetItemPrimitiveCreatorRegistry("TensorListGetItem", TensorListGetItemPrimitiveCreator);
RegistryMSOps g_tensorListReservePrimitiveCreatorRegistry("TensorListReserve", TensorListReservePrimitiveCreator);
RegistryMSOps g_tensorListSetItemPrimitiveCreatorRegistry("TensorListSetItem", TensorListSetItemPrimitiveCreator);
RegistryMSOps g_tensorListStackPrimitiveCreatorRegistry("TensorListStack", TensorListStackPrimitiveCreator);
RegistryMSOps g_tileFusionPrimitiveCreatorRegistry("TileFusion", TileFusionPrimitiveCreator);
RegistryMSOps g_topKPrimitiveCreatorRegistry("TopK", TopKFusionPrimitiveCreator);
RegistryMSOps g_topKFusionPrimitiveCreatorRegistry("TopKFusion", TopKFusionPrimitiveCreator);
RegistryMSOps g_transposePrimitiveCreatorxRegistry("Transpose", TransposePrimitiveCreator);
RegistryMSOps g_uniquePrimitiveCreatorRegistry("Unique", UniquePrimitiveCreator);
RegistryMSOps g_unpackPrimitiveCreatorRegistry("Unpack", UnpackPrimitiveCreator);
RegistryMSOps g_unsortedSegmentSumPrimitiveCreatorRegistry("UnsortedSegmentSum", UnsortedSegmentSumPrimitiveCreator);
RegistryMSOps g_unsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiveCreator);
RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator);
RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator);
} // namespace lite
} // namespace mindspore



+ 14
- 30
mindspore/lite/tools/converter/parser/caffe/caffe_activation_parser.cc View File

@@ -21,59 +21,43 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeReluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Activation();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ReLU failed";
return nullptr;
}
auto prim = std::make_unique<ops::Activation>();

primitive_c->set_activation_type(mindspore::ActivationType::RELU);
prim->set_activation_type(mindspore::ActivationType::RELU);

if (proto.has_relu_param() && proto.relu_param().has_negative_slope()) {
float negative_slope = proto.relu_param().negative_slope();
if (negative_slope != 0) {
primitive_c->set_activation_type(mindspore::ActivationType::LEAKY_RELU);
primitive_c->set_alpha(negative_slope);
prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU);
prim->set_alpha(negative_slope);
}
}

return primitive_c;
return prim.release();
}

ops::PrimitiveC *CaffeRelu6Parser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Activation();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Relu6 failed";
return nullptr;
}
auto prim = std::make_unique<ops::Activation>();

primitive_c->set_activation_type(mindspore::ActivationType::RELU6);
prim->set_activation_type(mindspore::ActivationType::RELU6);

return primitive_c;
return prim.release();
}

ops::PrimitiveC *CaffeSigmoidParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Activation();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Sigmoid failed";
return nullptr;
}
auto prim = std::make_unique<ops::Activation>();

primitive_c->set_activation_type(mindspore::ActivationType::SIGMOID);
prim->set_activation_type(mindspore::ActivationType::SIGMOID);

return primitive_c;
return prim.release();
}

ops::PrimitiveC *CaffeTanhParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Activation();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Tanh failed";
return nullptr;
}
auto prim = std::make_unique<ops::Activation>();

primitive_c->set_activation_type(mindspore::ActivationType::TANH);
prim->set_activation_type(mindspore::ActivationType::TANH);

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeReluParser("ReLU", new CaffeReluParser());


+ 8
- 12
mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.cc View File

@@ -21,28 +21,24 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::ArgMaxFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ArgMaxFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::ArgMaxFusion>();

primitive_c->set_keep_dims(true);
primitive_c->set_out_max_value(false);
primitive_c->set_top_k(1);
prim->set_keep_dims(true);
prim->set_out_max_value(false);
prim->set_top_k(1);

const caffe::ArgMaxParameter &argmaxParam = proto.argmax_param();
if (argmaxParam.has_out_max_val()) {
primitive_c->set_out_max_value(argmaxParam.out_max_val());
prim->set_out_max_value(argmaxParam.out_max_val());
}
if (argmaxParam.has_top_k()) {
primitive_c->set_top_k(argmaxParam.top_k());
prim->set_top_k(argmaxParam.top_k());
}
if (argmaxParam.has_axis()) {
primitive_c->set_axis(argmaxParam.axis());
prim->set_axis(argmaxParam.axis());
}

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeArgMaxParser("ArgMax", new CaffeArgMaxParser());


+ 6
- 10
mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc View File

@@ -24,11 +24,10 @@ namespace mindspore {
namespace lite {
using STATUS = int;
ops::PrimitiveC *CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::BatchNorm();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new BatchNorm failed";
return nullptr;
}
auto prim = std::make_unique<ops::BatchNorm>();

prim->set_is_training(false);
prim->set_format(mindspore::NCHW);

const caffe::BatchNormParameter &batchNormParam = proto.batch_norm_param();
if (proto.bottom_size() != 1) {
@@ -46,12 +45,9 @@ ops::PrimitiveC *CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto,
if (batchNormParam.has_eps() && std::fabs(1e-5 - batchNormParam.eps()) >= 1e-9) {
epsilon = batchNormParam.eps();
}
primitive_c->set_epsilon(epsilon);

primitive_c->set_is_training(false);
primitive_c->set_format(mindspore::NCHW);
prim->set_epsilon(epsilon);

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeBatchNormParser("BatchNorm", new CaffeBatchNormParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.cc View File

@@ -21,11 +21,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Concat();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Concat failed";
return nullptr;
}
auto prim = std::make_unique<ops::Concat>();

const caffe::ConcatParameter &concatParam = proto.concat_param();
if (concatParam.has_axis() && concatParam.has_concat_dim()) {
@@ -45,9 +41,9 @@ ops::PrimitiveC *CaffeConcatParser::Parse(const caffe::LayerParameter &proto, co
MS_LOG(DEBUG) << "set axis: " << concatParam.axis();
axis = concatParam.axis();
}
primitive_c->set_axis(axis);
prim->set_axis(axis);

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeConcatParser("Concat", new CaffeConcatParser());


+ 16
- 23
mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc View File

@@ -22,61 +22,52 @@ namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto,
const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Conv2DFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Conv2DFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::Conv2DFusion>();

primitive_c->set_pad({0, 0, 0, 0});
primitive_c->set_pad_mode(mindspore::PadMode::PAD);
primitive_c->set_format(mindspore::Format::NCHW);
primitive_c->set_activation_type(mindspore::NO_ACTIVATION);
prim->set_pad({0, 0, 0, 0});
prim->set_pad_mode(mindspore::PadMode::PAD);
prim->set_format(mindspore::Format::NCHW);
prim->set_activation_type(mindspore::NO_ACTIVATION);

const caffe::ConvolutionParameter &convParam = proto.convolution_param();
// parse kernel
std::vector<int64_t> kernel(2, 0);
if (CaffeConvBaseParser::ParseKernels(convParam, &kernel) != RET_OK) {
MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed";
return nullptr;
}
primitive_c->set_kernel_size(kernel);
prim->set_kernel_size(kernel);

// parse stride
std::vector<int64_t> stride(2, 0);
if (CaffeConvBaseParser::ParseStrides(convParam, &stride) != RET_OK) {
MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed";
return nullptr;
}
primitive_c->set_stride(stride);
prim->set_stride(stride);

// parse dilation
std::vector<int64_t> dilation(2, 0);
if (CaffeConvBaseParser::ParseDilations(convParam, &dilation) != RET_OK) {
MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed";
return nullptr;
}
primitive_c->set_dilation(dilation);
prim->set_dilation(dilation);

// parse pad
std::vector<int64_t> pad(4, 0);
if (CaffeConvBaseParser::ParsePads(convParam, &pad) != RET_OK) {
MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed";
return nullptr;
}
primitive_c->set_pad_list(pad);
prim->set_pad_list(pad);

// parse channelOut
int channel_out = 0;
if (CaffeConvBaseParser::ParseChannelOut(convParam, &channel_out) != RET_OK) {
MS_LOG(ERROR) << "conv channel out failed";
return nullptr;
}
primitive_c->set_out_channel(channel_out);
prim->set_out_channel(channel_out);

// parse group
auto group = CaffeConvBaseParser::ParseGroup(convParam, proto.type());
primitive_c->set_group(group);
prim->set_group(group);

// parse channelIn
if (weight.blobs_size() < 1) {
@@ -85,11 +76,13 @@ ops::PrimitiveC *CaffeConvolutionParser::Parse(const caffe::LayerParameter &prot
}
auto &weightBlob = weight.blobs(0);
auto channelIn = weightBlob.has_shape() ? weightBlob.shape().dim(1) * group : weightBlob.channels() * group;
primitive_c->set_in_channel(channelIn);
prim->set_in_channel(channelIn);

if (group != 1) {
primitive_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
}
return primitive_c;

return prim.release();
}

CaffeNodeRegistrar g_caffeConvolutionParser("Convolution", new CaffeConvolutionParser());


+ 7
- 11
mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.cc View File

@@ -21,25 +21,21 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeCropParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Crop();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Crop failed";
return nullptr;
}
auto prim = std::make_unique<ops::Crop>();

if (!proto.has_crop_param()) {
primitive_c->set_axis(2);
prim->set_axis(2);
std::vector<int64_t> offsets(2, 0);
primitive_c->set_offsets(offsets);
prim->set_offsets(offsets);
} else {
const caffe::CropParameter &cropParam = proto.crop_param();
if (cropParam.has_axis()) {
if (cropParam.axis() == -1) {
MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims.";
}
primitive_c->set_axis(cropParam.axis());
prim->set_axis(cropParam.axis());
} else {
primitive_c->set_axis(2);
prim->set_axis(2);
}

if (cropParam.offset_size() != 0) {
@@ -48,11 +44,11 @@ ops::PrimitiveC *CaffeCropParser::Parse(const caffe::LayerParameter &proto, cons
for (int i = 0; i < cropParam.offset_size(); i++) {
offsets.push_back(cropParam.offset(i));
}
primitive_c->set_offsets(offsets);
prim->set_offsets(offsets);
}
}

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeCropParser("Crop", new CaffeCropParser());


+ 16
- 25
mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc View File

@@ -20,78 +20,69 @@

namespace mindspore {
namespace lite {

ops::PrimitiveC *CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto,
const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Conv2dTransposeFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Conv2dTransposeFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::Conv2dTransposeFusion>();

primitive_c->set_pad({0, 0, 0, 0});
primitive_c->set_format(mindspore::Format::NCHW);
primitive_c->set_pad_mode(mindspore::PadMode::PAD);
prim->set_pad({0, 0, 0, 0});
prim->set_format(mindspore::Format::NCHW);
prim->set_pad_mode(mindspore::PadMode::PAD);

const caffe::ConvolutionParameter &convParam = proto.convolution_param();
// parse pad
std::vector<int64_t> pad(4, 0);
if (CaffeConvBaseParser::ParsePads(convParam, &pad) != RET_OK) {
MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed";
return nullptr;
}
primitive_c->set_pad_list({pad[0], pad[1], pad[2], pad[3]});
prim->set_pad_list({pad[0], pad[1], pad[2], pad[3]});

// parse stride
std::vector<int64_t> stride(2, 0);
if (CaffeConvBaseParser::ParseStrides(convParam, &stride) != RET_OK) {
MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed";
return nullptr;
}
primitive_c->set_stride({stride[0], stride[1]});
prim->set_stride({stride[0], stride[1]});

// parse dilation
std::vector<int64_t> dilation(2, 0);
if (CaffeConvBaseParser::ParseDilations(convParam, &dilation) != RET_OK) {
MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed";
return nullptr;
}
primitive_c->set_dilation({dilation[0], dilation[1]});
prim->set_dilation({dilation[0], dilation[1]});

// parse kernel
std::vector<int64_t> kernel(2, 0);
if (CaffeConvBaseParser::ParseKernels(convParam, &kernel) != RET_OK) {
MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed";
return nullptr;
}
primitive_c->set_kernel_size({kernel[0], kernel[1]});
prim->set_kernel_size({kernel[0], kernel[1]});

// parse group
auto group = CaffeConvBaseParser::ParseGroup(convParam, proto.type());
primitive_c->set_group(group);
prim->set_group(group);

// parse channelOut
int32_t channelOut;
if (CaffeConvBaseParser::ParseChannelOut(convParam, &channelOut) != RET_OK) {
MS_LOG(ERROR) << "deconv channel get failed";
return nullptr;
}
primitive_c->set_out_channel((int64_t)channelOut);
prim->set_out_channel((int64_t)channelOut);

// parse channelIN
auto &weightBlob = weight.blobs(0);
if (weightBlob.has_shape()) {
if (group == 1)
primitive_c->set_in_channel(weightBlob.shape().dim(0) * group);
prim->set_in_channel(weightBlob.shape().dim(0) * group);
else
primitive_c->set_in_channel(weightBlob.shape().dim(1) * group);
prim->set_in_channel(weightBlob.shape().dim(1) * group);
} else {
primitive_c->set_in_channel(weightBlob.num() * group);
prim->set_in_channel(weightBlob.num() * group);
}
if (group != 1) {
primitive_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
}
return primitive_c;

return prim.release();
}

CaffeNodeRegistrar g_caffeDeconvolutionParser("Deconvolution", new CaffeDeconvolutionParser());


+ 6
- 10
mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc View File

@@ -22,11 +22,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Eltwise();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Eltwise failed";
return nullptr;
}
auto prim = std::make_unique<ops::Eltwise>();

if (proto.bottom_size() < 2) {
MS_LOG(ERROR) << "Eltwise Op " << proto.name() << " need at least 2 inputs,but input size is "
@@ -55,23 +51,23 @@ ops::PrimitiveC *CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, c
if (proto.has_eltwise_param() && eltwiseParam.has_operation()) {
switch (eltwiseParam.operation()) {
case caffe::EltwiseParameter::PROD:
primitive_c->set_mode(mindspore::EltwiseMode::PROD);
prim->set_mode(mindspore::EltwiseMode::PROD);
break;
case caffe::EltwiseParameter::SUM:
primitive_c->set_mode(mindspore::EltwiseMode::SUM);
prim->set_mode(mindspore::EltwiseMode::SUM);
break;
case caffe::EltwiseParameter::MAX:
primitive_c->set_mode(mindspore::EltwiseMode::MAXIMUM);
prim->set_mode(mindspore::EltwiseMode::MAXIMUM);
break;
default:
MS_LOG(ERROR) << "Eltwise parse params fail, unsupported operation: " << eltwiseParam.operation();
return nullptr;
}
} else {
primitive_c->set_mode(mindspore::EltwiseMode::SUM);
prim->set_mode(mindspore::EltwiseMode::SUM);
}

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeEltwiseParser("Eltwise", new CaffeEltwiseParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/caffe/caffe_elu_parser.cc View File

@@ -21,20 +21,16 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Elu();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Elu failed";
return nullptr;
}
auto prim = std::make_unique<ops::Elu>();

if (proto.has_elu_param()) {
const caffe::ELUParameter &eluParameter = proto.elu_param();
if (eluParameter.has_alpha()) {
primitive_c->set_alpha(eluParameter.alpha());
prim->set_alpha(eluParameter.alpha());
}
}

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeEluParser("ELU", new CaffeEluParser());


+ 9
- 12
mindspore/lite/tools/converter/parser/caffe/caffe_exp_parser.cc View File

@@ -22,29 +22,26 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeExpParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::ExpFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ExpFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::ExpFusion>();

const caffe::ExpParameter &exp_param = proto.exp_param();
if (exp_param.has_base()) {
primitive_c->set_base(exp_param.base());
prim->set_base(exp_param.base());
} else {
primitive_c->set_base(-1); // -1 represent base = e
prim->set_base(-1); // -1 represent base = e
}
if (exp_param.has_scale()) {
primitive_c->set_scale(exp_param.scale());
prim->set_scale(exp_param.scale());
} else {
primitive_c->set_scale(1);
prim->set_scale(1);
}
if (exp_param.has_shift()) {
primitive_c->set_shift(exp_param.shift());
prim->set_shift(exp_param.shift());
} else {
primitive_c->set_shift(0);
prim->set_shift(0);
}

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeExpParser("Exp", new CaffeExpParser());


+ 2
- 6
mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc View File

@@ -21,13 +21,9 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Flatten();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Flatten failed";
return nullptr;
}
auto prim = std::make_unique<ops::Flatten>();

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_CaffeFlattenParser("Flatten", new CaffeFlattenParser());


+ 7
- 11
mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc View File

@@ -22,11 +22,9 @@ namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto,
const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::FullConnection();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new FullConnection failed";
return nullptr;
}
auto prim = std::make_unique<ops::FullConnection>();

prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION);

const caffe::InnerProductParameter &innerProductParam = proto.inner_product_param();
if (!innerProductParam.has_num_output()) {
@@ -35,19 +33,17 @@ ops::PrimitiveC *CaffeInnerProductParser::Parse(const caffe::LayerParameter &pro
}

if (innerProductParam.axis() == 1) {
primitive_c->set_axis(1);
primitive_c->set_use_axis(true);
prim->set_axis(1);
prim->set_use_axis(true);
} else {
MS_LOG(ERROR) << "InnerProduct Parse axis only support default 1, but actually " << innerProductParam.axis();
return nullptr;
}
if (innerProductParam.bias_term()) {
primitive_c->set_has_bias(true);
prim->set_has_bias(true);
}

primitive_c->set_activation_type(mindspore::ActivationType::NO_ACTIVATION);

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeInnerProductParser("InnerProduct", new CaffeInnerProductParser());


+ 7
- 10
mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc View File

@@ -21,11 +21,10 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Resize();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Resize failed";
return nullptr;
}
auto prim = std::make_unique<ops::Resize>();

prim->set_method(mindspore::ResizeMethod::LINEAR);
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ALIGN_CORNERS);

const caffe::InterpParameter &interpParam = proto.interp_param();
if (interpParam.has_height()) {
@@ -34,7 +33,7 @@ ops::PrimitiveC *CaffeInterpParser::Parse(const caffe::LayerParameter &proto, co
MS_LOG(ERROR) << "Interp height must be > 0";
return nullptr;
}
primitive_c->set_new_height(height);
prim->set_new_height(height);
}

if (interpParam.has_width()) {
@@ -43,12 +42,10 @@ ops::PrimitiveC *CaffeInterpParser::Parse(const caffe::LayerParameter &proto, co
MS_LOG(ERROR) << "Interp width must be > 0";
return nullptr;
}
primitive_c->set_new_width(width);
prim->set_new_width(width);
}
primitive_c->set_method(mindspore::ResizeMethod::LINEAR);
primitive_c->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ALIGN_CORNERS);

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeInterpParser("Interp", new CaffeInterpParser());


+ 8
- 8
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc View File

@@ -95,8 +95,8 @@ STATUS CaffeModelParser::ConvertLayers() {
continue;
}

auto primitive_c = node_parser->Parse(layer, weight);
if (primitive_c == nullptr) {
auto prim = node_parser->Parse(layer, weight);
if (prim == nullptr) {
MS_LOG(ERROR) << "parse node " << layer.name() << " failed.";
status = RET_ERROR;
continue;
@@ -119,7 +119,7 @@ STATUS CaffeModelParser::ConvertLayers() {
}

// build cnode
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))};
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(prim))};
op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end());
op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end());
auto new_cnode = func_graph_ptr_->NewCNode(op_inputs);
@@ -132,7 +132,7 @@ STATUS CaffeModelParser::ConvertLayers() {
continue;
}

status = ConvertLayerQuantParams(layer, weight, primitive_c);
status = ConvertLayerQuantParams(layer, weight, prim);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert quant params for " << layer.name() << " failed.";
continue;
@@ -294,9 +294,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
}

STATUS CaffeModelParser::ConvertLayerQuantParams(const caffe::LayerParameter &layer,
const caffe::LayerParameter &weight, ops::PrimitiveC *primitive_c) {
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
const caffe::LayerParameter &weight, ops::PrimitiveC *prim) {
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is null, get quant params failed.";
return RET_NULL_PTR;
}
auto quant_params_holder = std::make_shared<QuantParamHolder>();
@@ -312,7 +312,7 @@ STATUS CaffeModelParser::ConvertLayerQuantParams(const caffe::LayerParameter &la
std::vector<schema::QuantParamT> notinited_quant_params(1);
quant_params_holder->AddOutputQuantParam(notinited_quant_params);
}
primitive_c->AddAttr("quant_params", quant_params_holder);
prim->AddAttr("quant_params", quant_params_holder);
return RET_OK;
}



+ 1
- 1
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h View File

@@ -45,7 +45,7 @@ class CaffeModelParser : public ModelParser {
STATUS ConvertLayers();

static STATUS ConvertLayerQuantParams(const caffe::LayerParameter &layer, const caffe::LayerParameter &weight,
ops::PrimitiveC *primitive_c);
ops::PrimitiveC *prim);

STATUS ConvertBlobs(const caffe::LayerParameter &layer, std::vector<ParameterPtr> *const_parameters);



+ 3
- 7
mindspore/lite/tools/converter/parser/caffe/caffe_permute_parser.cc View File

@@ -21,11 +21,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffePermuteParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Transpose();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Transpose failed";
return nullptr;
}
auto prim = std::make_unique<ops::Transpose>();

std::vector<int32_t> perm;
const caffe::PermuteParameter &permuteParam = proto.permute_param();
@@ -34,9 +30,9 @@ ops::PrimitiveC *CaffePermuteParser::Parse(const caffe::LayerParameter &proto, c
for (int i = 0; i < num_order_dims; ++i) {
perm[i] = permuteParam.order()[i];
}
primitive_c->AddAttr("perm", MakeValue(perm));
prim->AddAttr("perm", MakeValue(perm));

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffePermuteParser("Permute", new CaffePermuteParser());


+ 16
- 24
mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.cc View File

@@ -124,31 +124,23 @@ ops::PrimitiveC *CaffePoolingParser::Parse(const caffe::LayerParameter &proto, c
auto roundMode = ParseRoundMode(poolingParam);

if (poolingParam.pool() == caffe::PoolingParameter::MAX) {
auto primitive_c = new (std::nothrow) ops::MaxPoolFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new MaxPoolFusion failed";
return nullptr;
}
primitive_c->set_format(mindspore::Format::NCHW);
primitive_c->set_pad_mode(mindspore::PadMode::PAD);
primitive_c->set_kernel_size(windows);
primitive_c->set_strides(strides);
primitive_c->set_pad(pad);
primitive_c->set_round_mode(roundMode);
return primitive_c;
auto prim = std::make_unique<ops::MaxPoolFusion>();
prim->set_format(mindspore::Format::NCHW);
prim->set_pad_mode(mindspore::PadMode::PAD);
prim->set_kernel_size(windows);
prim->set_strides(strides);
prim->set_pad(pad);
prim->set_round_mode(roundMode);
return prim.release();
} else if (poolingParam.pool() == caffe::PoolingParameter::AVE) {
auto primitive_c = new (std::nothrow) ops::AvgPoolFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new AvgPoolFusion failed";
return nullptr;
}
primitive_c->set_format(mindspore::Format::NCHW);
primitive_c->set_pad_mode(mindspore::PadMode::PAD);
primitive_c->set_kernel_size(windows);
primitive_c->set_strides(strides);
primitive_c->set_pad(pad);
primitive_c->set_round_mode(roundMode);
return primitive_c;
auto prim = std::make_unique<ops::AvgPoolFusion>();
prim->set_format(mindspore::Format::NCHW);
prim->set_pad_mode(mindspore::PadMode::PAD);
prim->set_kernel_size(windows);
prim->set_strides(strides);
prim->set_pad(pad);
prim->set_round_mode(roundMode);
return prim.release();
} else {
MS_LOG(ERROR) << "poolingParam.pool() is not MAX or AVE";
return nullptr;


+ 5
- 9
mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.cc View File

@@ -21,11 +21,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffePowerParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::PowFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new PowFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::PowFusion>();

const caffe::PowerParameter &powerParam = proto.power_param();
float power = 1.0;
@@ -42,11 +38,11 @@ ops::PrimitiveC *CaffePowerParser::Parse(const caffe::LayerParameter &proto, con
shift = powerParam.shift();
}
}
primitive_c->AddAttr("power", MakeValue(power));
primitive_c->set_scale(scale);
primitive_c->set_shift(shift);
prim->AddAttr("power", MakeValue(power));
prim->set_scale(scale);
prim->set_shift(shift);

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffePowerParser("Power", new CaffePowerParser());


+ 6
- 10
mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.cc View File

@@ -21,20 +21,16 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffePReluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::PReLUFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new PReLUFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::PReLUFusion>();

const caffe::PReLUParameter &pReluParam = proto.prelu_param();
if (pReluParam.has_channel_shared()) {
primitive_c->set_channel_shared(pReluParam.channel_shared());
const caffe::PReLUParameter &prelu_param = proto.prelu_param();
if (prelu_param.has_channel_shared()) {
prim->set_channel_shared(prelu_param.channel_shared());
} else {
primitive_c->set_channel_shared(false);
prim->set_channel_shared(false);
}

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffePReluParser("PReLU", new CaffePReluParser());


+ 9
- 13
mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc View File

@@ -22,30 +22,26 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeReduceParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::ReduceFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ReduceFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::ReduceFusion>();

primitive_c->set_keep_dims(false);
prim->set_keep_dims(false);

const caffe::ReductionParameter &reduce_param = proto.reduction_param();
if (reduce_param.has_operation()) {
if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_MEAN) {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Mean);
prim->set_mode(mindspore::ReduceMode::Reduce_Mean);
} else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUM) {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum);
prim->set_mode(mindspore::ReduceMode::Reduce_Sum);
} else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUMSQ) {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum_Square);
prim->set_mode(mindspore::ReduceMode::Reduce_Sum_Square);
} else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_ASUM) {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_ASum);
prim->set_mode(mindspore::ReduceMode::Reduce_ASum);
} else {
MS_LOG(ERROR) << "nsupported reduce mode: " << reduce_param.operation();
return nullptr;
}
} else {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum);
prim->set_mode(mindspore::ReduceMode::Reduce_Sum);
}

std::vector<int32_t> axes;
@@ -56,9 +52,9 @@ ops::PrimitiveC *CaffeReduceParser::Parse(const caffe::LayerParameter &proto, co
axes.push_back(1);
axes.push_back(0);
}
primitive_c->AddAttr("axes", MakeValue(axes));
prim->AddAttr("axes", MakeValue(axes));

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeReduceParser("Reduction", new CaffeReduceParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.cc View File

@@ -21,11 +21,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeReshapeParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Reshape();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Reshape failed";
return nullptr;
}
auto prim = std::make_unique<ops::Reshape>();

const caffe::ReshapeParameter &reshapeParam = proto.reshape_param();
if (!reshapeParam.has_shape()) {
@@ -37,9 +33,9 @@ ops::PrimitiveC *CaffeReshapeParser::Parse(const caffe::LayerParameter &proto, c
for (int i = 0; i < blob_shape.dim_size(); i++) {
shape.push_back(blob_shape.dim(i));
}
primitive_c->AddAttr("shape", MakeValue(shape));
prim->AddAttr("shape", MakeValue(shape));

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeReshapeParser("Reshape", new CaffeReshapeParser());


+ 3
- 21
mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc View File

@@ -20,26 +20,8 @@

namespace mindspore {
namespace lite {
STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) {
if (axis < -4 || axis >= 4) {
MS_LOG(ERROR) << "Scale axis value(" << axis << ") is not correct";
return RET_ERROR;
}

if (axis == -1) {
MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims.";
}

*axis_index = (axis + 4) % 4;
return RET_OK;
}

ops::PrimitiveC *CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::ScaleFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ScaleFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::ScaleFusion>();

if (weight.blobs_size() + weight.bottom_size() < 2) {
MS_LOG(ERROR) << "Scale bottom size:" << weight.bottom_size() << ", blobs size:" << weight.blobs_size()
@@ -58,9 +40,9 @@ ops::PrimitiveC *CaffeScaleParser::Parse(const caffe::LayerParameter &proto, con
MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims.";
}
}
primitive_c->set_axis(1);
prim->set_axis(1);

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeScaleParser("Scale", new CaffeScaleParser());


+ 0
- 2
mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h View File

@@ -29,8 +29,6 @@ class CaffeScaleParser : public CaffeNodeParser {
~CaffeScaleParser() override = default;

ops::PrimitiveC *Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) override;

static STATUS GetAxisIndex(const int32_t &axis, uint32_t *axis_index);
};
} // namespace lite
} // namespace mindspore


+ 7
- 11
mindspore/lite/tools/converter/parser/caffe/caffe_slice_parser.cc View File

@@ -21,16 +21,12 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeSliceParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Split();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Split failed";
return nullptr;
}
auto prim = std::make_unique<ops::Split>();

const caffe::SliceParameter &slice_param = proto.slice_param();
primitive_c->set_output_num(2);
prim->set_output_num(2);
if (!slice_param.slice_point().empty()) {
primitive_c->set_output_num(slice_param.slice_point_size() + 1);
prim->set_output_num(slice_param.slice_point_size() + 1);
std::vector<int64_t> size_splits;
for (int i = 0; i < slice_param.slice_point_size(); ++i) {
if (i == 0) {
@@ -40,16 +36,16 @@ ops::PrimitiveC *CaffeSliceParser::Parse(const caffe::LayerParameter &proto, con
}
}
size_splits.push_back(-1);
primitive_c->set_size_splits(size_splits);
prim->set_size_splits(size_splits);
}

if (slice_param.has_axis()) {
primitive_c->set_axis(slice_param.axis());
prim->set_axis(slice_param.axis());
} else if (slice_param.has_slice_dim()) {
primitive_c->set_axis(slice_param.slice_dim());
prim->set_axis(slice_param.slice_dim());
}

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeSliceParser("Slice", new CaffeSliceParser());


+ 4
- 8
mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.cc View File

@@ -21,22 +21,18 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeSoftmaxParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::Softmax();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Softmax failed";
return nullptr;
}
auto prim = std::make_unique<ops::Softmax>();

if (proto.has_softmax_param() && proto.softmax_param().has_axis()) {
if (proto.softmax_param().axis() == -1) {
MS_LOG(DEBUG) << "axis with -1 may lead to calculation errors when input less than 4 dims.";
}
primitive_c->set_axis({proto.softmax_param().axis()});
prim->set_axis({proto.softmax_param().axis()});
} else {
primitive_c->set_axis({1});
prim->set_axis({1});
}

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeSoftmaxParser("Softmax", new CaffeSoftmaxParser());


+ 5
- 8
mindspore/lite/tools/converter/parser/caffe/caffe_tile_parser.cc View File

@@ -22,11 +22,8 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *CaffeTileParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
auto primitive_c = new (std::nothrow) ops::TileFusion();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new TileFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::TileFusion>();

const caffe::TileParameter &tile_param = proto.tile_param();
std::vector<int64_t> dims;
dims.clear();
@@ -35,7 +32,7 @@ ops::PrimitiveC *CaffeTileParser::Parse(const caffe::LayerParameter &proto, cons
} else {
dims.push_back(1);
}
primitive_c->set_dims(dims);
prim->set_dims(dims);

std::vector<int32_t> multiples;
multiples.clear();
@@ -44,9 +41,9 @@ ops::PrimitiveC *CaffeTileParser::Parse(const caffe::LayerParameter &proto, cons
} else {
multiples.push_back(1);
}
primitive_c->AddAttr("multiples", MakeValue(multiples));
prim->AddAttr("multiples", MakeValue(multiples));

return primitive_c;
return prim.release();
}

CaffeNodeRegistrar g_caffeTileParser("Tile", new CaffeTileParser());


+ 22
- 46
mindspore/lite/tools/converter/parser/onnx/onnx_activation_parser.cc View File

@@ -25,42 +25,30 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Activation;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ReLU failed";
return nullptr;
}
auto prim = std::make_unique<ops::Activation>();

primitive_c->set_activation_type(mindspore::ActivationType::RELU);
prim->set_activation_type(mindspore::ActivationType::RELU);

return primitive_c;
return prim.release();
}

ops::PrimitiveC *OnnxLeakyReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Activation;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LeakyRelu failed";
return nullptr;
}
auto prim = std::make_unique<ops::Activation>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "alpha") {
primitive_c->set_alpha(onnx_node_attr.f());
prim->set_alpha(onnx_node_attr.f());
}
}

primitive_c->set_activation_type(mindspore::ActivationType::LEAKY_RELU);
prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU);

return primitive_c;
return prim.release();
}

ops::PrimitiveC *OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::PReLUFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new PReLU failed";
return nullptr;
}
auto prim = std::make_unique<ops::PReLUFusion>();

std::vector<onnx::TensorProto> params;
const auto &input_name = onnx_node.input(1);
@@ -82,10 +70,10 @@ ops::PrimitiveC *OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, cons
const auto slope_raw_data = reinterpret_cast<const float *>(slope_data->raw_data().data());
const int64_t slope_size = slope_data->raw_data().size() / sizeof(float);
std::vector<float> slope;
bool channelShared = false;
bool channel_shared = false;
if (slope_size == 1) {
slope.push_back(*slope_raw_data);
channelShared = true;
channel_shared = true;
} else {
slope.resize(slope_size);
if (memcpy_s(slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) {
@@ -93,54 +81,42 @@ ops::PrimitiveC *OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, cons
return nullptr;
}
}
primitive_c->set_slope(slope);
primitive_c->set_channel_shared(channelShared);
prim->set_slope(slope);
prim->set_channel_shared(channel_shared);
} else {
MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors.";
}

return primitive_c;
return prim.release();
}

ops::PrimitiveC *OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Elu;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Elu failed";
return nullptr;
}
auto prim = std::make_unique<ops::Elu>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "alpha") {
primitive_c->set_alpha(onnx_node_attr.f());
prim->set_alpha(onnx_node_attr.f());
}
}

return primitive_c;
return prim.release();
}

ops::PrimitiveC *OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Activation;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Tanh failed";
return nullptr;
}
auto prim = std::make_unique<ops::Activation>();

primitive_c->set_activation_type(mindspore::ActivationType::TANH);
prim->set_activation_type(mindspore::ActivationType::TANH);

return primitive_c;
return prim.release();
}

ops::PrimitiveC *OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Activation;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Sigmoid failed";
return nullptr;
}
auto prim = std::make_unique<ops::Activation>();

primitive_c->set_activation_type(mindspore::ActivationType::SIGMOID);
prim->set_activation_type(mindspore::ActivationType::SIGMOID);

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser());


+ 2
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.cc View File

@@ -21,13 +21,8 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxAdderParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::AdderFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new AdderFusion failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::AdderFusion>();
return prim.release();
}

OnnxNodeRegistrar g_onnxAdderParser("adder_f", new OnnxAdderParser());


+ 4
- 8
mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc View File

@@ -21,22 +21,18 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::ArgMaxFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ArgMax failed";
return nullptr;
}
auto prim = std::make_unique<ops::ArgMaxFusion>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
primitive_c->set_axis(onnx_node_attr.i());
prim->set_axis(onnx_node_attr.i());
} else if (attribute_name == "keepdims") {
primitive_c->set_keep_dims(static_cast<bool>(onnx_node_attr.i()));
prim->set_keep_dims(static_cast<bool>(onnx_node_attr.i()));
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser());


+ 62
- 199
mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc View File

@@ -50,297 +50,160 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::AddFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new AddFusion failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::AddFusion>();
return prim.release();
}

ops::PrimitiveC *OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::SubFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new SubFusion failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::SubFusion>();
return prim.release();
}

ops::PrimitiveC *OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::DivFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new DivFusion failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::DivFusion>();
return prim.release();
}

ops::PrimitiveC *OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::MulFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new MulFusion failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::MulFusion>();
return prim.release();
}

ops::PrimitiveC *OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Equal;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Equal failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Equal>();
return prim.release();
}

ops::PrimitiveC *OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Less;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Less failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Less>();
return prim.release();
}

ops::PrimitiveC *OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Greater;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Greater failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Greater>();
return prim.release();
}

ops::PrimitiveC *OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Floor;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Floor failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Floor>();
return prim.release();
}

ops::PrimitiveC *OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Abs;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Abs failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Abs>();
return prim.release();
}

ops::PrimitiveC *OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::ExpFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ExpFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::ExpFusion>();

primitive_c->set_base(-1.0);
primitive_c->set_scale(1.0);
primitive_c->set_shift(0.0);
prim->set_base(-1.0);
prim->set_scale(1.0);
prim->set_shift(0.0);

return primitive_c;
return prim.release();
}

ops::PrimitiveC *OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Cos;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Cos failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Cos>();
return prim.release();
}

ops::PrimitiveC *OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Ceil;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Ceil failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Ceil>();
return prim.release();
}

ops::PrimitiveC *OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Log;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Log failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Log>();
return prim.release();
}

ops::PrimitiveC *OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Atan;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Atan failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Atan>();
return prim.release();
}

ops::PrimitiveC *OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Asin;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Asin failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Asin>();
return prim.release();
}

ops::PrimitiveC *OnnxAndParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::LogicalAnd;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LogicalAnd failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::LogicalAnd>();
return prim.release();
}

ops::PrimitiveC *OnnxOrParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::LogicalOr;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LogicalOr failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::LogicalOr>();
return prim.release();
}

ops::PrimitiveC *OnnxNotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::LogicalNot;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LogicalNot failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::LogicalNot>();
return prim.release();
}

ops::PrimitiveC *OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Neg;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Neg failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Neg>();
return prim.release();
}

ops::PrimitiveC *OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Round;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Round failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Round>();
return prim.release();
}

ops::PrimitiveC *OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Sin;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new sin failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Sin>();
return prim.release();
}

ops::PrimitiveC *OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Tan;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Tan failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Tan>();
return prim.release();
}

ops::PrimitiveC *OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Sqrt;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Sqrt failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Sqrt>();
return prim.release();
}

ops::PrimitiveC *OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::PowFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new PowFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::PowFusion>();

primitive_c->set_scale(1.0);
primitive_c->set_shift(0.0);
prim->set_scale(1.0);
prim->set_shift(0.0);

return primitive_c;
return prim.release();
}

ops::PrimitiveC *OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Minimum;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Minimum failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Minimum>();
return prim.release();
}

ops::PrimitiveC *OnnxMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Maximum;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Maximum failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Maximum>();
return prim.release();
}

ops::PrimitiveC *OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Eltwise;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Eltwise failed";
return nullptr;
}
auto prim = std::make_unique<ops::Eltwise>();

if (onnx_node.op_type() == "Sum") {
primitive_c->set_mode(mindspore::EltwiseMode::SUM);
prim->set_mode(mindspore::EltwiseMode::SUM);
} else {
MS_LOG(ERROR) << "unsupported Eltwise type";
return nullptr;
}

return primitive_c;
return prim.release();
}

ops::PrimitiveC *OnnxReciprocalParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Reciprocal;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Reciprocal failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Reciprocal>();
return prim.release();
}

OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());


+ 4
- 8
mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc View File

@@ -21,21 +21,17 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::FusedBatchNorm;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new FusedBatchNorm failed";
return nullptr;
}
auto prim = std::make_unique<ops::FusedBatchNorm>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "epsilon") {
primitive_c->set_epsilon(onnx_node_attr.f());
prim->set_epsilon(onnx_node_attr.f());
} else if (onnx_node_attr.name() == "momentum") {
primitive_c->set_momentum(onnx_node_attr.f());
prim->set_momentum(onnx_node_attr.f());
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser());


+ 2
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc View File

@@ -21,13 +21,8 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::BiasAdd;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new BiasAdd failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::BiasAdd>();
return prim.release();
}

OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc View File

@@ -23,11 +23,7 @@ namespace mindspore {
namespace lite {

ops::PrimitiveC *OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Cast;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Cast failed";
return nullptr;
}
auto prim = std::make_unique<ops::Cast>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
@@ -36,11 +32,11 @@ ops::PrimitiveC *OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const
if (dst_type == kNumberTypeInt64) {
dst_type = kNumberTypeInt32;
}
primitive_c->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type)));
prim->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type)));
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser());


+ 6
- 10
mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc View File

@@ -21,24 +21,20 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Clip;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Clip failed";
return nullptr;
}
auto prim = std::make_unique<ops::Clip>();

primitive_c->set_min(-1);
primitive_c->set_max(-1);
prim->set_min(-1);
prim->set_max(-1);
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "max") {
primitive_c->set_max(onnx_node_attr.f());
prim->set_max(onnx_node_attr.f());
} else if (attribute_name == "min") {
primitive_c->set_min(onnx_node_attr.f());
prim->set_min(onnx_node_attr.f());
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc View File

@@ -21,20 +21,16 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Concat;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Concat failed";
return nullptr;
}
auto prim = std::make_unique<ops::Concat>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
primitive_c->set_axis(onnx_node_attr.i());
prim->set_axis(onnx_node_attr.i());
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser());


+ 4
- 8
mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc View File

@@ -24,11 +24,7 @@ namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::ConstantOfShape;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ConstantOfShape failed";
return nullptr;
}
auto prim = std::make_unique<ops::ConstantOfShape>();

int data_type = 0;
std::vector<float> values;
@@ -61,10 +57,10 @@ ops::PrimitiveC *OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_g
if (values.empty()) {
values = {0};
}
primitive_c->set_value(values);
primitive_c->set_data_type((int64_t)data_type);
prim->set_value(values);
prim->set_data_type((int64_t)data_type);

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser());


+ 5
- 11
mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc View File

@@ -24,7 +24,7 @@

namespace mindspore {
namespace lite {
STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, ops::PrimitiveC *primitive_c) {
STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, ops::PrimitiveC *prim) {
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
if (param_value == nullptr) {
MS_LOG(ERROR) << "new a paramValueLite failed.";
@@ -48,16 +48,12 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t
MS_LOG(ERROR) << "get value failed.";
return RET_ERROR;
}
primitive_c->set_attr("const_data", param_value);
prim->set_attr("const_data", param_value);
return RET_OK;
}

ops::PrimitiveC *OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Constant;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Constant failed";
return nullptr;
}
auto prim = std::make_unique<ops::Constant>();

for (const auto &attr : onnx_node.attribute()) {
if (attr.name() == "sparse_value") {
@@ -66,18 +62,16 @@ ops::PrimitiveC *OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, c
}
if (attr.name() == "value") {
const auto &const_tensor = attr.t();
if (AddDataInfoAttr(const_tensor, primitive_c) != RET_OK) {
if (AddDataInfoAttr(const_tensor, prim.get()) != RET_OK) {
MS_LOG(ERROR) << "add basic attr failed.";
delete primitive_c;
return nullptr;
}
} else {
MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented";
delete primitive_c;
return nullptr;
}
}
return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser());


+ 1
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h View File

@@ -29,7 +29,7 @@ class OnnxConstantParser : public OnnxNodeParser {

ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;

STATUS AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, ops::PrimitiveC *primitive_c);
STATUS AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, ops::PrimitiveC *prim);
};
} // namespace lite
} // namespace mindspore


+ 30
- 34
mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc View File

@@ -20,22 +20,17 @@
#include <vector>
#include <string>
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/depthwise_conv2d_fusion.h"

namespace mindspore::lite {
ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Conv2DFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Conv2DFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::Conv2DFusion>();

primitive_c->set_pad({0, 0, 0, 0});
prim->set_pad({0, 0, 0, 0});
mindspore::Format format = mindspore::Format::NCHW;
mindspore::PadMode padMode = mindspore::PadMode::PAD;
mindspore::PadMode pad_mode = mindspore::PadMode::PAD;

int64_t channelOut = 1;
int64_t channelIn = 1;
int64_t channel_out = 1;
int64_t channel_in = 1;
int64_t group = 1;
std::vector<int64_t> kernels;
std::vector<int64_t> strides;
@@ -59,7 +54,7 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const
}
kernels.push_back(onnx_node_attr.ints(0));
kernels.push_back(onnx_node_attr.ints(1));
primitive_c->set_kernel_size(kernels);
prim->set_kernel_size(kernels);
} else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
@@ -67,14 +62,14 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const
}
kernels.push_back(onnx_node_attr.ints(0));
kernels.push_back(onnx_node_attr.ints(1));
primitive_c->set_kernel_size(kernels);
prim->set_kernel_size(kernels);
} else if (onnx_node_attr.name() == "auto_pad") {
if (onnx_node_attr.s() == "SAME_UPPER") {
padMode = mindspore::PadMode::SAME;
pad_mode = mindspore::PadMode::SAME;
} else if (onnx_node_attr.s() == "VALID") {
padMode = mindspore::PadMode::VALID;
pad_mode = mindspore::PadMode::VALID;
} else if (onnx_node_attr.s() == "NOTSET") {
padMode = mindspore::PadMode::PAD;
pad_mode = mindspore::PadMode::PAD;
} else if (onnx_node_attr.s() == "SAME_LOWER") {
MS_LOG(ERROR) << "unsupported padMode";
return nullptr;
@@ -88,7 +83,7 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const
pads.push_back(onnx_node_attr.ints(2));
pads.push_back(onnx_node_attr.ints(1));
pads.push_back(onnx_node_attr.ints(3));
primitive_c->set_pad_list(pads);
prim->set_pad_list(pads);
} else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
@@ -96,7 +91,7 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const
}
strides.push_back(onnx_node_attr.ints(0));
strides.push_back(onnx_node_attr.ints(1));
primitive_c->set_stride(strides);
prim->set_stride(strides);
} else if (onnx_node_attr.name() == "order") {
if (onnx_node_attr.s() == "NHWC") {
format = mindspore::Format::NHWC;
@@ -109,18 +104,18 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const
if (dilation.empty()) {
dilation = {1, 1};
}
primitive_c->set_dilation(dilation);
prim->set_dilation(dilation);

if (pads.empty()) {
pads = {0, 0, 0, 0};
}
primitive_c->set_pad_list(pads);
prim->set_pad_list(pads);

primitive_c->set_format(format);
primitive_c->set_pad_mode(padMode);
primitive_c->set_group(group);
prim->set_format(format);
prim->set_pad_mode(pad_mode);
prim->set_group(group);

// get channelOut and channelIn
// get channel_out and channel_in
const auto &onnx_conv_weight = onnx_node.input(1);
if (onnx_node.op_type() == "Conv") {
auto node_iter =
@@ -135,8 +130,8 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const
for (int i = 0; i < size; ++i) {
weight_shape.emplace_back((*node_iter).dims(i));
}
channelOut = weight_shape[0];
channelIn = weight_shape[1] * group;
channel_out = weight_shape[0];
channel_in = weight_shape[1] * group;
}
} else {
auto node_iter =
@@ -156,23 +151,24 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const
}
dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end());
}
channelOut = dims.at(0);
channelIn = dims.at(3) * group;
channel_out = dims.at(0);
channel_in = dims.at(3) * group;
}
primitive_c->set_in_channel(channelIn);
primitive_c->set_out_channel(channelOut);
prim->set_in_channel(channel_in);
prim->set_out_channel(channel_out);

// parse activationType
if (onnx_node.op_type() == "ConvRelu" || onnx_node.op_type() == "Int8ConvRelu") {
primitive_c->set_activation_type(mindspore::ActivationType::RELU);
prim->set_activation_type(mindspore::ActivationType::RELU);
} else {
primitive_c->set_activation_type(mindspore::ActivationType::NO_ACTIVATION);
prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION);
}
if (group == channelIn && channelIn == channelOut) {
primitive_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));

if (group == channel_in && channel_in == channel_out) {
prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser());


+ 19
- 20
mindspore/lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc View File

@@ -23,15 +23,12 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Conv2dTransposeFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Conv2dTransposeFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::Conv2dTransposeFusion>();

primitive_c->set_pad({0, 0, 0, 0});
prim->set_pad({0, 0, 0, 0});
mindspore::Format format = mindspore::Format::NCHW;
mindspore::PadMode padMode = mindspore::PadMode::PAD;
mindspore::PadMode pad_mode = mindspore::PadMode::PAD;

int64_t group = 1;
std::vector<int64_t> kernel;
std::vector<int64_t> dilate;
@@ -47,7 +44,7 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con
}
dilate.push_back(onnx_node_attr.ints(0));
dilate.push_back(onnx_node_attr.ints(1));
primitive_c->set_dilation(dilate);
prim->set_dilation(dilate);
} else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
@@ -55,7 +52,7 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con
}
kernel.push_back(onnx_node_attr.ints(0));
kernel.push_back(onnx_node_attr.ints(1));
primitive_c->set_kernel_size(kernel);
prim->set_kernel_size(kernel);
} else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
@@ -63,9 +60,9 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con
}
kernel.push_back(onnx_node_attr.ints(0));
kernel.push_back(onnx_node_attr.ints(1));
primitive_c->set_kernel_size(kernel);
prim->set_kernel_size(kernel);
} else if (onnx_node_attr.name() == "auto_pad") {
padMode = GetOnnxPadMode(onnx_node_attr);
pad_mode = GetOnnxPadMode(onnx_node_attr);
} else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) {
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
@@ -75,7 +72,7 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con
pads.push_back(onnx_node_attr.ints(2));
pads.push_back(onnx_node_attr.ints(1));
pads.push_back(onnx_node_attr.ints(3));
primitive_c->set_pad_list(pads);
prim->set_pad_list(pads);
} else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
@@ -83,7 +80,7 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con
}
stride.push_back(onnx_node_attr.ints(0));
stride.push_back(onnx_node_attr.ints(1));
primitive_c->set_stride(stride);
prim->set_stride(stride);
} else if (onnx_node_attr.name() == "order") {
if (onnx_node_attr.s() == "NHWC") {
format = mindspore::Format::NHWC;
@@ -96,9 +93,9 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con
return nullptr;
}
}
primitive_c->set_format(format);
primitive_c->set_group(group);
primitive_c->set_pad_mode(padMode);
prim->set_format(format);
prim->set_group(group);
prim->set_pad_mode(pad_mode);

const auto &onnx_conv_weight = onnx_node.input(1);
auto node_iter =
@@ -118,12 +115,14 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con
MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size();
return nullptr;
}
primitive_c->set_in_channel(weight_shape[0]);
primitive_c->set_out_channel(weight_shape[1] * group);
prim->set_in_channel(weight_shape[0]);
prim->set_out_channel(weight_shape[1] * group);

if (group != 1 && weight_shape[1] == 1) {
primitive_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
}
return primitive_c;

return prim.release();
}

OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc View File

@@ -21,20 +21,16 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::DepthToSpace;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new DepthToSpace failed";
return nullptr;
}
auto prim = std::make_unique<ops::DepthToSpace>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "blocksize") {
primitive_c->set_block_size(onnx_node_attr.i());
prim->set_block_size(onnx_node_attr.i());
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc View File

@@ -21,20 +21,16 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Dropout;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Dropout failed";
return nullptr;
}
auto prim = std::make_unique<ops::Dropout>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "ratio") {
primitive_c->set_keep_prob(onnx_node_attr.f());
prim->set_keep_prob(onnx_node_attr.f());
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc View File

@@ -22,11 +22,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::BroadcastTo;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new BroadcastTo failed";
return nullptr;
}
auto prim = std::make_unique<ops::BroadcastTo>();

std::vector<int64_t> dst_shape;
const auto &onnx_expand_power = onnx_node.input(1);
@@ -46,9 +42,9 @@ ops::PrimitiveC *OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, con
}
}
}
primitive_c->set_shape(dst_shape);
prim->set_shape(dst_shape);

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser());


+ 2
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc View File

@@ -21,13 +21,8 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Flatten;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Flatten failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Flatten>();
return prim.release();
}

OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc View File

@@ -21,11 +21,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Gather;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Gather failed";
return nullptr;
}
auto prim = std::make_unique<ops::Gather>();

int32_t axis = 0;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@@ -34,9 +30,9 @@ ops::PrimitiveC *OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, con
axis = static_cast<int32_t>(onnx_node_attr.i());
}
}
primitive_c->AddAttr("axis", MakeValue(axis));
prim->AddAttr("axis", MakeValue(axis));

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser());


+ 4
- 8
mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.cc View File

@@ -22,11 +22,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::MakeTuple;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new MakeTuple failed";
return nullptr;
}
auto prim = std::make_unique<ops::MakeTuple>();

auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul");
if (node_parser == nullptr) {
@@ -34,7 +30,7 @@ ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const
return nullptr;
}
auto *matmul_primitive = node_parser->Parse(onnx_graph, onnx_node);
primitive_c->AddAttr("MatMul", std::shared_ptr<ops::PrimitiveC>(matmul_primitive));
prim->AddAttr("MatMul", std::shared_ptr<ops::PrimitiveC>(matmul_primitive));

node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd");
if (node_parser == nullptr) {
@@ -42,9 +38,9 @@ ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const
return nullptr;
}
auto *bias_add_primitive = node_parser->Parse(onnx_graph, onnx_node);
primitive_c->AddAttr("BiasAdd", std::shared_ptr<ops::PrimitiveC>(bias_add_primitive));
prim->AddAttr("BiasAdd", std::shared_ptr<ops::PrimitiveC>(bias_add_primitive));

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser());


+ 11
- 22
mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc View File

@@ -24,14 +24,10 @@

namespace mindspore {
namespace lite {
STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node,
ops::PrimitiveC *primitive_c,
STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim,
const std::vector<int> &shape) {
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
if (param_value == nullptr) {
MS_LOG(ERROR) << "new a paramValueLite failed.";
return RET_ERROR;
}

int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
[](const onnx::AttributeProto &attr) { return attr.name() == "values"; });
@@ -46,6 +42,7 @@ STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodePr
}
if (iter->ints().data() == nullptr) {
MS_LOG(ERROR) << "origin ints data in onnx is nullptr";
delete[] param_data;
return RET_NULL_PTR;
}
if (memcpy_s(param_data, data_size, iter->ints().data(), data_size) != EOK) {
@@ -57,18 +54,14 @@ STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodePr
param_value->set_format(schema::Format_NUM_OF_FORMAT);
param_value->set_tensor_type(kNumberTypeInt64);
param_value->SetTensorData(param_data, data_size);
primitive_c->set_attr("const_data", param_value);
prim->set_attr("const_data", param_value);
return RET_OK;
}

STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node,
ops::PrimitiveC *primitive_c,
STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim,
const std::vector<int> &shape) {
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
if (param_value == nullptr) {
MS_LOG(ERROR) << "new a paramValueLite failed.";
return RET_ERROR;
}

int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
[](const onnx::AttributeProto &attr) { return attr.name() == "values"; });
@@ -89,16 +82,12 @@ STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto
param_value->set_format(schema::Format_NUM_OF_FORMAT);
param_value->set_tensor_type(kNumberTypeUInt8);
param_value->SetTensorData(param_data, data_count);
primitive_c->set_attr("const_data", param_value);
prim->set_attr("const_data", param_value);
return RET_OK;
}
ops::PrimitiveC *OnnxGivenTensorFillParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Constant;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Constant failed";
return nullptr;
}
auto prim = std::make_unique<ops::Constant>();

std::vector<int64_t> shape_vector;
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
@@ -110,18 +99,18 @@ ops::PrimitiveC *OnnxGivenTensorFillParser::Parse(const onnx::GraphProto &onnx_g
std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
[](const int64_t &val) { return static_cast<int32_t>(val); });
if (onnx_node.op_type() == "Int8GivenIntTensorFill") {
if (ParseInt8GivenIntTensorFill(onnx_node, primitive_c, shape) != RET_OK) {
if (ParseInt8GivenIntTensorFill(onnx_node, prim.get(), shape) != RET_OK) {
MS_LOG(ERROR) << "given tensor fill parse failed.";
return nullptr;
}
} else if (onnx_node.op_type() == "Int8GivenTensorFill") {
if (ParseInt8GivenTensorFill(onnx_node, primitive_c, shape) != RET_OK) {
if (ParseInt8GivenTensorFill(onnx_node, prim.get(), shape) != RET_OK) {
MS_LOG(ERROR) << "given tensor fill parse failed.";
return nullptr;
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxInt8GivenIntTensorFillParser("Int8GivenIntTensorFill", new OnnxGivenTensorFillParser());


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h View File

@@ -30,9 +30,9 @@ class OnnxGivenTensorFillParser : public OnnxNodeParser {

ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;

STATUS ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c,
STATUS ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim,
const std::vector<int> &shape);
STATUS ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c,
STATUS ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim,
const std::vector<int> &shape);
};
} // namespace lite


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.cc View File

@@ -21,14 +21,10 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Identity;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Identity failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Identity>();
return prim.release();
}

OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser());
} // namespace lite
} // namespace mindspore

+ 4
- 8
mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc View File

@@ -21,22 +21,18 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::LayerNormFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LayerNormFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::LayerNormFusion>();

primitive_c->set_elementwise_affine(true);
prim->set_elementwise_affine(true);

if (!onnx_node.attribute().empty()) {
auto onnx_node_attr = onnx_node.attribute().at(0);
if (onnx_node_attr.name() == "epsilon") {
primitive_c->set_epsilon(onnx_node_attr.f());
prim->set_epsilon(onnx_node_attr.f());
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser());


+ 4
- 8
mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc View File

@@ -21,22 +21,18 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::LpNormalization;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LpNormalization failed";
return nullptr;
}
auto prim = std::make_unique<ops::LpNormalization>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
primitive_c->set_axis(onnx_node_attr.i());
prim->set_axis(onnx_node_attr.i());
} else if (attribute_name == "p") {
primitive_c->set_p(onnx_node_attr.i());
prim->set_p(onnx_node_attr.i());
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxLpNormParser("LpNormalization", new OnnxLpNormParser());


+ 6
- 10
mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc View File

@@ -21,11 +21,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Lrn;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LRN failed";
return nullptr;
}
auto prim = std::make_unique<ops::Lrn>();

int64_t size = 0;
float alpha = 0;
@@ -34,12 +30,12 @@ ops::PrimitiveC *OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const
if (attribute_name == "alpha") {
alpha = onnx_node_attr.f();
} else if (attribute_name == "beta") {
primitive_c->set_beta(onnx_node_attr.f());
prim->set_beta(onnx_node_attr.f());
} else if (attribute_name == "bias") {
primitive_c->set_bias(onnx_node_attr.f());
prim->set_bias(onnx_node_attr.f());
} else if (attribute_name == "size") {
size = onnx_node_attr.i();
primitive_c->set_depth_radius(size / 2);
prim->set_depth_radius(size / 2);
}
}

@@ -48,9 +44,9 @@ ops::PrimitiveC *OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const
return nullptr;
}
alpha /= size;
primitive_c->set_alpha(alpha);
prim->set_alpha(alpha);

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser());


+ 8
- 12
mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.cc View File

@@ -21,32 +21,28 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::LSTM;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LSTM failed";
return nullptr;
}
auto prim = std::make_unique<ops::LSTM>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "direction") {
const auto &direction = onnx_node_attr.s();
bool bidirectional = direction == "bidirectional";
primitive_c->set_bidirectional(bidirectional);
prim->set_bidirectional(bidirectional);
if (bidirectional) {
primitive_c->set_num_directions(2);
prim->set_num_directions(2);
} else {
primitive_c->set_num_directions(1);
prim->set_num_directions(1);
}
} else if (onnx_node_attr.name() == "hidden_size") {
primitive_c->set_hidden_size(onnx_node_attr.i());
prim->set_hidden_size(onnx_node_attr.i());
} else if (onnx_node_attr.name() == "clip") {
primitive_c->set_dropout(onnx_node_attr.f());
prim->set_dropout(onnx_node_attr.f());
} else if (onnx_node_attr.name() == "activations") {
primitive_c->set_has_bias(true);
prim->set_has_bias(true);
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxLstmParser("LSTM", new OnnxLstmParser());


+ 4
- 8
mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc View File

@@ -21,20 +21,16 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::MatMul;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new MatMul failed";
return nullptr;
}
auto prim = std::make_unique<ops::MatMul>();

float alpha = 1.0f;
float beta = 1.0f;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "transA") {
primitive_c->set_transpose_a(static_cast<bool>(onnx_node_attr.i()));
prim->set_transpose_a(static_cast<bool>(onnx_node_attr.i()));
} else if (attribute_name == "transB") {
primitive_c->set_transpose_b(static_cast<bool>(onnx_node_attr.i()));
prim->set_transpose_b(static_cast<bool>(onnx_node_attr.i()));
} else if (attribute_name == "alpha") {
alpha = onnx_node_attr.f();
} else if (attribute_name == "beta") {
@@ -46,7 +42,7 @@ ops::PrimitiveC *OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, con
return nullptr;
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser());


+ 30
- 45
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc View File

@@ -48,10 +48,7 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st
const QuantType &quant_type) {
NoSupportOp::GetInstance()->SetFmkType("ONNX");
func_graph_ptr_ = std::make_shared<FuncGraph>();
if (func_graph_ptr_ == nullptr) {
MS_LOG(ERROR) << "funcgraph is nullptr.";
return nullptr;
}

auto status = InitOriginModel(model_file);
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -164,25 +161,24 @@ STATUS OnnxModelParser::ConvertNodes() {
if (status != RET_OK) {
continue;
}
auto primitive_c = node_parser->Parse(onnx_graph_, onnx_node);
auto prim = node_parser->Parse(onnx_graph_, onnx_node);
MS_LOG(INFO) << "parse op:" << onnx_node.op_type();
if (primitive_c == nullptr) {
if (prim == nullptr) {
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
status = RET_ERROR;
continue;
}
status = ConvertOpQuantParams(onnx_node, primitive_c);
if (status != RET_OK) {
if (ConvertOpQuantParams(onnx_node, prim) != RET_OK) {
MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
continue;
}
if (IsSpecialOnnxNode(onnx_node)) {
auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c);
auto status_node = ConvertSpecialOnnxNode(onnx_node, prim);
status = status == RET_OK ? status_node : status;
continue;
}
// build CNode
status = BuildCNode(onnx_node, primitive_c);
status = BuildCNode(onnx_node, prim);
if (status != RET_OK) {
MS_LOG(ERROR) << "build cnode " << onnx_node.op_type() << " failed.";
}
@@ -195,10 +191,7 @@ STATUS OnnxModelParser::ConvertGraphOutputs() {
if (onnx_graph_.output_size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs;
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new return nullptr";
return RET_NULL_PTR;
}

for (const auto &graph_out : onnx_graph_.output()) {
if (nodes_.find(graph_out.name()) == nodes_.end()) {
MS_LOG(ERROR) << "graph output get failed.";
@@ -236,19 +229,15 @@ STATUS OnnxModelParser::ConvertGraphOutputs() {

STATUS OnnxModelParser::BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs) {
auto returnPrim = std::make_shared<ops::Return>();
if (returnPrim == nullptr) {
MS_LOG(ERROR) << "new return nullptr";
return RET_NULL_PTR;
}
auto returnCnode = func_graph_ptr_->NewCNode(returnPrim, return_inputs);
returnCnode->set_fullname_with_scope("return");
func_graph_ptr_->set_return(returnCnode);
return RET_OK;
}

STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) {
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr.";
STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim) {
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr.";
return RET_NULL_PTR;
}
std::vector<AnfNodePtr> op_inputs;
@@ -263,7 +252,7 @@ STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, ops::Primit
op_inputs.push_back(nodes_[input_name]);
}
}
auto new_cnode = func_graph_ptr_->NewCNode(std::shared_ptr<ops::PrimitiveC>(primitive_c), op_inputs);
auto new_cnode = func_graph_ptr_->NewCNode(std::shared_ptr<ops::PrimitiveC>(prim), op_inputs);
new_cnode->set_fullname_with_scope(onnx_node.op_type() + "_" + onnx_node.output(0));
auto status = BuildOpOutputs(onnx_node, new_cnode);
return status;
@@ -287,10 +276,6 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const C
auto type_ptr = TypeIdToType(kTypeUnknown);
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new return nullptr";
return RET_NULL_PTR;
}
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value};
@@ -304,9 +289,9 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const C
return RET_OK;
}

STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) {
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim) {
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is null, get quant params failed.";
return RET_NULL_PTR;
}
auto status = ParseQuantParam(onnx_node);
@@ -337,7 +322,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
}
quant_params_holder->AddOutputQuantParam(quant_params);
}
primitive_c->AddAttr("quant_params", quant_params_holder);
prim->AddAttr("quant_params", quant_params_holder);
return RET_OK;
}

@@ -462,8 +447,8 @@ STATUS OnnxModelParser::CopyTensorQuantParam(const std::string &tensor_name, Qua
return RET_OK;
}

STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) {
if (primitive_c == nullptr) {
STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim) {
if (prim == nullptr) {
MS_LOG(ERROR) << "imitive_c is nullptr.";
return RET_NULL_PTR;
}
@@ -472,30 +457,30 @@ STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node,
MS_LOG(ERROR) << "loop hasn't supported.";
return RET_NOT_FIND_OP;
} else if (onnx_node.op_type() == "Gemm") {
status = ConvertOnnxGemmNode(onnx_node, primitive_c);
status = ConvertOnnxGemmNode(onnx_node, prim);
} else {
MS_LOG(ERROR) << "the node is not special node.";
status = RET_ERROR;
}
delete primitive_c;
delete prim;
return status;
}

STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) {
STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim) {
if (onnx_node.op_type() != "Gemm") {
MS_LOG(ERROR) << "this op is not gemm, it is " << onnx_node.op_type();
return RET_ERROR;
}
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr.";
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr.";
return RET_NULL_PTR;
}
auto status = BuildCNodeForGemm(onnx_node, primitive_c, "MatMul");
auto status = BuildCNodeForGemm(onnx_node, prim, "MatMul");
if (status != RET_OK) {
MS_LOG(ERROR) << "convert gemm node failed.";
return status;
}
status = BuildCNodeForGemm(onnx_node, primitive_c, "BiasAdd");
status = BuildCNodeForGemm(onnx_node, prim, "BiasAdd");
if (status != RET_OK) {
MS_LOG(ERROR) << "convert gemm node failed.";
return status;
@@ -503,14 +488,14 @@ STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, op
return RET_OK;
}

STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c,
STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim,
const std::string &name) {
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr.";
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr.";
return RET_NULL_PTR;
}
auto value = primitive_c->GetAttr(name);
primitive_c->EraseAttr(name);
auto value = prim->GetAttr(name);
prim->EraseAttr(name);
if (value == nullptr) {
MS_LOG(ERROR) << "op parse failed.";
return RET_NULL_PTR;
@@ -524,7 +509,7 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops:
std::vector<int64_t> shape_vector;
std::vector<AnfNodePtr> op_inputs;
auto quant_params_holder = std::make_shared<QuantParamHolder>();
auto quant_params_holder_origin = primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
auto quant_params_holder_origin = prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
if (name == "MatMul") {
for (int i = 0; i < 2; ++i) {
if (nodes_.find(onnx_node.input(i)) == nodes_.end()) {


+ 5
- 5
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h View File

@@ -55,12 +55,12 @@ class OnnxModelParser : public ModelParser {
STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs);
STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor);
STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type);
STATUS BuildCNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
STATUS BuildCNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim);
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode);
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c, const std::string &name);
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim);
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim);
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim, const std::string &name);
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim);
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc View File

@@ -22,22 +22,18 @@ namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::NonMaxSuppression;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new NonMaxSuppression failed";
return nullptr;
}
auto prim = std::make_unique<ops::NonMaxSuppression>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "center_point_box") {
if (onnx_node_attr.has_i()) {
primitive_c->set_center_point_box(onnx_node_attr.i());
prim->set_center_point_box(onnx_node_attr.i());
}
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc View File

@@ -21,20 +21,16 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::OneHot;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new OneHot failed";
return nullptr;
}
auto prim = std::make_unique<ops::OneHot>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
primitive_c->set_axis(onnx_node_attr.i());
prim->set_axis(onnx_node_attr.i());
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser());


+ 9
- 13
mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc View File

@@ -22,13 +22,9 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::PadFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new PadFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::PadFusion>();

mindspore::PaddingMode paddingMode;
mindspore::PaddingMode padding_mode;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "pads") {
@@ -40,28 +36,28 @@ ops::PrimitiveC *OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const
paddings[i][0] = static_cast<int64_t>(onnx_node_attr.ints(i));
paddings[i][1] = static_cast<int64_t>(onnx_node_attr.ints(i + size / 2));
}
primitive_c->set_paddings(paddings);
prim->set_paddings(paddings);

std::vector<std::vector<int32_t>> pads(size / 2, std::vector<int32_t>(2, 0));
for (int i = 0; i < size / 2; i++) {
pads[i][0] = static_cast<int32_t>(onnx_node_attr.ints(i));
pads[i][1] = static_cast<int32_t>(onnx_node_attr.ints(i + size / 2));
}
primitive_c->AddAttr("pads", MakeValue(pads));
prim->AddAttr("pads", MakeValue(pads));
} else if (attribute_name == "mode") {
const auto &mode = onnx_node_attr.s();
if (mode == "constant") {
paddingMode = mindspore::PaddingMode::CONSTANT;
padding_mode = mindspore::PaddingMode::CONSTANT;
} else if (mode == "reflect") {
paddingMode = mindspore::PaddingMode::REFLECT;
padding_mode = mindspore::PaddingMode::REFLECT;
} else if (mode == "edge") {
paddingMode = mindspore::PaddingMode::SYMMETRIC;
padding_mode = mindspore::PaddingMode::SYMMETRIC;
}
primitive_c->set_padding_mode(paddingMode);
prim->set_padding_mode(padding_mode);
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser());


+ 21
- 29
mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc View File

@@ -23,14 +23,10 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxAvgPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::AvgPoolFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new AvgPoolFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::AvgPoolFusion>();

primitive_c->set_format(mindspore::Format::NCHW);
primitive_c->set_pad_mode(mindspore::PadMode::PAD);
prim->set_format(mindspore::Format::NCHW);
prim->set_pad_mode(mindspore::PadMode::PAD);
mindspore::RoundMode roundMode = mindspore::RoundMode::FLOOR;
std::vector<int64_t> kernels;
std::vector<int64_t> strides;
@@ -41,7 +37,7 @@ ops::PrimitiveC *OnnxAvgPoolParser::Parse(const onnx::GraphProto &onnx_graph, co
if (onnx_node_attr.ints_size() == 2) {
kernels.push_back(onnx_node_attr.ints(0));
kernels.push_back(onnx_node_attr.ints(1));
primitive_c->set_kernel_size(kernels);
prim->set_kernel_size(kernels);
}
}
if (attribute_name == "strides") {
@@ -52,7 +48,7 @@ ops::PrimitiveC *OnnxAvgPoolParser::Parse(const onnx::GraphProto &onnx_graph, co
}
if (attribute_name == "auto_pad") {
if (onnx_node_attr.s() == "SAME_UPPER") {
primitive_c->set_pad_mode(mindspore::PadMode::SAME);
prim->set_pad_mode(mindspore::PadMode::SAME);
} else if (onnx_node_attr.s() == "SAME_LOWER") {
MS_LOG(ERROR) << "PadMode_SAME_LOWER is not supported now";
return nullptr;
@@ -78,34 +74,30 @@ ops::PrimitiveC *OnnxAvgPoolParser::Parse(const onnx::GraphProto &onnx_graph, co
return nullptr;
}
}
primitive_c->set_round_mode(roundMode);
prim->set_round_mode(roundMode);

if (strides.empty()) {
strides.push_back(1);
strides.push_back(1);
}
primitive_c->set_strides(strides);
prim->set_strides(strides);
if (pads.empty()) {
pads = {0, 0, 0, 0};
}
primitive_c->set_pad(pads);
prim->set_pad(pads);
if (onnx_node.op_type() == "GlobalAveragePool") {
primitive_c->set_global(true);
prim->set_global(true);
} else {
primitive_c->set_global(false);
prim->set_global(false);
}

return primitive_c;
return prim.release();
}

ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::MaxPoolFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new MaxPoolFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::MaxPoolFusion>();

primitive_c->set_format(mindspore::Format::NCHW);
prim->set_format(mindspore::Format::NCHW);
mindspore::RoundMode roundMode = mindspore::RoundMode::FLOOR;
std::vector<int64_t> kernels;
std::vector<int64_t> strides;
@@ -116,7 +108,7 @@ ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, co
if (onnx_node_attr.ints_size() == 2) {
kernels.push_back(onnx_node_attr.ints(0));
kernels.push_back(onnx_node_attr.ints(1));
primitive_c->set_kernel_size(kernels);
prim->set_kernel_size(kernels);
}
}
if (attribute_name == "strides") {
@@ -127,7 +119,7 @@ ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, co
}
if (attribute_name == "auto_pad") {
if (onnx_node_attr.s() == "SAME_UPPER") {
primitive_c->set_pad_mode(mindspore::PadMode::SAME);
prim->set_pad_mode(mindspore::PadMode::SAME);
} else if (onnx_node_attr.s() == "SAME_LOWER") {
MS_LOG(ERROR) << "PadMode_SAME_LOWER is not supported now";
return nullptr;
@@ -135,7 +127,7 @@ ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, co
}
if (attribute_name == "pads") {
if (onnx_node_attr.ints_size() == 4) {
primitive_c->set_pad_mode(mindspore::PadMode::PAD);
prim->set_pad_mode(mindspore::PadMode::PAD);
pads.push_back(onnx_node_attr.ints(0));
pads.push_back(onnx_node_attr.ints(2));
pads.push_back(onnx_node_attr.ints(1));
@@ -154,22 +146,22 @@ ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, co
return nullptr;
}
}
primitive_c->set_round_mode(roundMode);
prim->set_round_mode(roundMode);

if (pads.empty()) {
pads = {0, 0, 0, 0};
}
primitive_c->set_pad(pads);
prim->set_pad(pads);

if (strides.empty()) {
strides.push_back(1);
strides.push_back(1);
}
primitive_c->set_strides(strides);
prim->set_strides(strides);

primitive_c->set_global(onnx_node.op_type() == "GlobalMaxPool");
prim->set_global(onnx_node.op_type() == "GlobalMaxPool");

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxAveragePoolParser("AveragePool", new OnnxAvgPoolParser());


+ 6
- 10
mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc View File

@@ -21,24 +21,20 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::QuantDTypeCast;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new QuantDTypeCast failed";
return nullptr;
}
auto prim = std::make_unique<ops::QuantDTypeCast>();

if (onnx_node.op_type() == "Int8Quantize") {
primitive_c->set_src_t(kNumberTypeFloat32);
primitive_c->set_dst_t(kNumberTypeUInt8);
prim->set_src_t(kNumberTypeFloat32);
prim->set_dst_t(kNumberTypeUInt8);
} else if (onnx_node.op_type() == "Int8Dequantize") {
primitive_c->set_src_t(kNumberTypeUInt8);
primitive_c->set_dst_t(kNumberTypeFloat32);
prim->set_src_t(kNumberTypeUInt8);
prim->set_dst_t(kNumberTypeFloat32);
} else {
MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
return nullptr;
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.cc View File

@@ -21,15 +21,11 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxRangeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Range;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Range failed";
return nullptr;
}
auto prim = std::make_unique<ops::Range>();

primitive_c->set_d_type(0);
prim->set_d_type(0);

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxRangeParser("Range", new OnnxRangeParser());


+ 11
- 15
mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc View File

@@ -22,13 +22,9 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::ReduceFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ReduceFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::ReduceFusion>();

primitive_c->set_keep_dims(true);
prim->set_keep_dims(true);
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") {
@@ -37,30 +33,30 @@ ops::PrimitiveC *OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, con
for (int i = 0; i < size; ++i) {
axes.push_back(onnx_node_attr.ints(i));
}
primitive_c->AddAttr("axes", MakeValue(axes));
prim->AddAttr("axes", MakeValue(axes));
} else if (attribute_name == "keepdims") {
primitive_c->set_keep_dims(static_cast<bool>(onnx_node_attr.i()));
prim->set_keep_dims(static_cast<bool>(onnx_node_attr.i()));
}
}
const auto &type = onnx_node.op_type();
if (type == "ReduceMean") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Mean);
prim->set_mode(mindspore::ReduceMode::Reduce_Mean);
} else if (type == "ReduceMax") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Max);
prim->set_mode(mindspore::ReduceMode::Reduce_Max);
} else if (type == "ReduceMin") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Min);
prim->set_mode(mindspore::ReduceMode::Reduce_Min);
} else if (type == "ReduceSum") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum);
prim->set_mode(mindspore::ReduceMode::Reduce_Sum);
} else if (type == "ReduceProd") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Prod);
prim->set_mode(mindspore::ReduceMode::Reduce_Prod);
} else if (type == "ReduceSumSquare") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum_Square);
prim->set_mode(mindspore::ReduceMode::Reduce_Sum_Square);
} else {
MS_LOG(ERROR) << "unsupported reduce type: " << type;
return nullptr;
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc View File

@@ -22,11 +22,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Reshape;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Reshape failed";
return nullptr;
}
auto prim = std::make_unique<ops::Reshape>();

std::vector<int32_t> shape;
shape.clear();
@@ -37,12 +33,12 @@ ops::PrimitiveC *OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, co
for (int i = 0; i < onnx_node_attr.ints_size(); ++i) {
shape.push_back(static_cast<int>(onnx_node_attr.ints(i)));
}
primitive_c->AddAttr("shape", MakeValue(shape));
prim->AddAttr("shape", MakeValue(shape));
}
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser());


+ 11
- 14
mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.cc View File

@@ -34,14 +34,11 @@ ops::PrimitiveC *OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, con
}

// use bilinear method
auto primitive_c = new (std::nothrow) ops::Resize;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Resize failed";
return nullptr;
}
auto prim = std::make_unique<ops::Resize>();

prim->set_format(mindspore::Format::NCHW);
prim->set_nearest_mode(mindspore::NearestMode::ROUND_HALF_DOWN);

primitive_c->set_format(mindspore::Format::NCHW);
primitive_c->set_nearest_mode(mindspore::NearestMode::ROUND_HALF_DOWN);
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "coordinate_transformation_mode") {
@@ -51,24 +48,24 @@ ops::PrimitiveC *OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, con
{"align_corners", mindspore::CoordinateTransformMode::ALIGN_CORNERS},
{"asymmetric", mindspore::CoordinateTransformMode::ASYMMETRIC}};
if (transform_map.find(onnx_node_attr.s()) != transform_map.end()) {
primitive_c->set_coordinate_transform_mode(transform_map[onnx_node_attr.s()]);
prim->set_coordinate_transform_mode(transform_map[onnx_node_attr.s()]);
} else {
MS_LOG(ERROR) << "Unsupport coordinate transform mode: " << attribute_name;
return nullptr;
}
} else if (attribute_name == "cubic_coeff_a") {
primitive_c->set_cubic_coeff(onnx_node_attr.f());
prim->set_cubic_coeff(onnx_node_attr.f());
} else if (attribute_name == "exclude_outside") {
primitive_c->set_exclude_outside(onnx_node_attr.i());
prim->set_exclude_outside(onnx_node_attr.i());
} else if (attribute_name == "extrapolation_value") {
primitive_c->set_extrapolation_value(onnx_node_attr.f());
prim->set_extrapolation_value(onnx_node_attr.f());
} else if (attribute_name == "mode") {
std::map<std::string, mindspore::ResizeMethod> resize_mode = {
{"nearest", mindspore::ResizeMethod::NEAREST},
{"linear", mindspore::ResizeMethod::LINEAR},
{"cubic", mindspore::ResizeMethod::CUBIC},
};
primitive_c->set_method(resize_mode[onnx_node_attr.s()]);
prim->set_method(resize_mode[onnx_node_attr.s()]);
} else if (attribute_name == "nearest_mode") {
std::map<std::string, mindspore::NearestMode> nearest_mode = {
{"round_prefer_floor", mindspore::NearestMode::ROUND_HALF_DOWN},
@@ -76,11 +73,11 @@ ops::PrimitiveC *OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, con
{"floor", mindspore::NearestMode::FLOOR},
{"ceil", mindspore::NearestMode::CEIL},
};
primitive_c->set_nearest_mode(nearest_mode[onnx_node_attr.s()]);
prim->set_nearest_mode(nearest_mode[onnx_node_attr.s()]);
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser());


+ 2
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc View File

@@ -21,13 +21,8 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Shape;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Shape failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::Shape>();
return prim.release();
}

OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser());


+ 8
- 11
mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc View File

@@ -26,11 +26,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::StridedSlice;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new StridedSlice failed";
return nullptr;
}
auto prim = std::make_unique<ops::StridedSlice>();

std::vector<int32_t> starts;
std::vector<int32_t> ends;
@@ -76,7 +72,7 @@ ops::PrimitiveC *OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, cons
size = static_cast<int>(steps.size());
}
if (size == -1) {
return primitive_c;
return prim.release();
}
if (axes.empty()) {
for (size_t i = 0; i < starts.size(); ++i) {
@@ -87,11 +83,12 @@ ops::PrimitiveC *OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, cons
steps.assign(starts.size(), 1);
}

primitive_c->AddAttr("starts", MakeValue(starts));
primitive_c->AddAttr("axes", MakeValue(axes));
primitive_c->AddAttr("ends", MakeValue(ends));
primitive_c->AddAttr("steps", MakeValue(steps));
return primitive_c;
prim->AddAttr("starts", MakeValue(starts));
prim->AddAttr("axes", MakeValue(axes));
prim->AddAttr("ends", MakeValue(ends));
prim->AddAttr("steps", MakeValue(steps));

return prim.release();
}

OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc View File

@@ -21,11 +21,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Softmax;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new SoftMax failed";
return nullptr;
}
auto prim = std::make_unique<ops::Softmax>();

int64_t axis;
bool axis_is_def = true;
@@ -39,9 +35,9 @@ ops::PrimitiveC *OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, co
if (axis_is_def) {
axis = OnnxNodeParser::opset_version() >= 13 ? -1 : 1;
}
primitive_c->set_axis({axis});
prim->set_axis({axis});

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc View File

@@ -21,20 +21,16 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::SpaceToDepth;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new SpaceToDepth failed";
return nullptr;
}
auto prim = std::make_unique<ops::SpaceToDepth>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "blocksize") {
primitive_c->set_block_size(onnx_node_attr.i());
prim->set_block_size(onnx_node_attr.i());
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser());


+ 7
- 10
mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc View File

@@ -23,31 +23,28 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Split;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Split failed";
return nullptr;
}
auto prim = std::make_unique<ops::Split>();

primitive_c->set_axis(0);
prim->set_axis(0);
std::vector<int64_t> size_splits;
int64_t split_num = 0;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
primitive_c->set_axis(onnx_node_attr.i());
prim->set_axis(onnx_node_attr.i());
} else if (attribute_name == "split") {
size_splits.resize(onnx_node_attr.ints_size());
std::copy(onnx_node_attr.ints().begin(), onnx_node_attr.ints().end(), size_splits.begin());
primitive_c->set_size_splits(size_splits);
prim->set_size_splits(size_splits);
split_num = onnx_node_attr.ints_size();
}
}
if (split_num == 0) {
split_num = onnx_node.output_size();
}
primitive_c->set_output_num(split_num);
return primitive_c;
prim->set_output_num(split_num);

return prim.release();
}

OnnxNodeRegistrar g_onnxSplitParser("Split", new OnnxSplitParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc View File

@@ -22,11 +22,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Squeeze;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Squeeze failed";
return nullptr;
}
auto prim = std::make_unique<ops::Squeeze>();

std::vector<int64_t> axis;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@@ -35,11 +31,11 @@ ops::PrimitiveC *OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, co
for (int i = 0; i < onnx_node_attr.ints().size(); ++i) {
axis.emplace_back(onnx_node_attr.ints(i));
}
primitive_c->set_axis(axis);
prim->set_axis(axis);
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser());


+ 2
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc View File

@@ -21,13 +21,8 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::TileFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new TileFusion failed";
return nullptr;
}

return primitive_c;
auto prim = std::make_unique<ops::TileFusion>();
return prim.release();
}

OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.cc View File

@@ -21,20 +21,16 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::TopKFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new TopKFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::TopKFusion>();

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "k") {
primitive_c->AddAttr("k", MakeValue(static_cast<int32_t>(onnx_node_attr.i())));
prim->AddAttr("k", MakeValue(static_cast<int32_t>(onnx_node_attr.i())));
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxTopkParser("TopK", new OnnxTopkParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc View File

@@ -22,11 +22,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Transpose;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Transpose failed";
return nullptr;
}
auto prim = std::make_unique<ops::Transpose>();

std::vector<int32_t> perm;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@@ -36,11 +32,11 @@ ops::PrimitiveC *OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph,
for (int i = 0; i < onnx_node_attr.ints_size(); ++i) {
perm[i] = onnx_node_attr.ints(i);
}
primitive_c->AddAttr("perm", MakeValue(perm));
prim->AddAttr("perm", MakeValue(perm));
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc View File

@@ -22,11 +22,7 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto primitive_c = new (std::nothrow) ops::Unsqueeze;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Unsqueeze failed";
return nullptr;
}
auto prim = std::make_unique<ops::Unsqueeze>();

std::vector<int64_t> axis;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@@ -35,11 +31,11 @@ ops::PrimitiveC *OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
for (int i = 0; i < onnx_node_attr.ints().size(); ++i) {
axis.emplace_back(onnx_node_attr.ints(i));
}
primitive_c->set_axis(axis);
prim->set_axis(axis);
}
}

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxUnsqueezeParser("Unsqueeze", new OnnxUnSqueezeParser());


+ 7
- 11
mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc View File

@@ -23,14 +23,10 @@
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
// use bilinear method
auto primitive_c = new (std::nothrow) ops::Resize;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Resize failed";
return nullptr;
}
auto prim = std::make_unique<ops::Resize>();

prim->set_method(mindspore::ResizeMethod::NEAREST); // use bilinear method

primitive_c->set_method(mindspore::ResizeMethod::NEAREST);
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "mode") {
@@ -38,13 +34,13 @@ ops::PrimitiveC *OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, c
MS_LOG(ERROR) << "the UpSample mode don't support now.";
return nullptr;
}
primitive_c->set_method(onnx_node_attr.s() == "nearest" ? mindspore::ResizeMethod::NEAREST
: mindspore::ResizeMethod::LINEAR);
prim->set_method(onnx_node_attr.s() == "nearest" ? mindspore::ResizeMethod::NEAREST
: mindspore::ResizeMethod::LINEAR);
}
}
primitive_c->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ASYMMETRIC);
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ASYMMETRIC);

return primitive_c;
return prim.release();
}

OnnxNodeRegistrar g_onnxUpsampleParser("Upsample", new OnnxUpsampleParser());


+ 7
- 10
mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc View File

@@ -25,20 +25,16 @@ namespace lite {
ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Activation();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Activation failed";
return nullptr;
}
auto prim = std::make_unique<ops::Activation>();

if (tf_op.op() == "Relu") {
primitive_c->set_activation_type(mindspore::ActivationType::RELU);
prim->set_activation_type(mindspore::ActivationType::RELU);
} else if (tf_op.op() == "Relu6") {
primitive_c->set_activation_type(mindspore::ActivationType::RELU6);
prim->set_activation_type(mindspore::ActivationType::RELU6);
} else if (tf_op.op() == "Sigmoid") {
primitive_c->set_activation_type(mindspore::ActivationType::SIGMOID);
prim->set_activation_type(mindspore::ActivationType::SIGMOID);
} else if (tf_op.op() == "Tanh") {
primitive_c->set_activation_type(mindspore::ActivationType::TANH);
prim->set_activation_type(mindspore::ActivationType::TANH);
} else {
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
return nullptr;
@@ -49,7 +45,8 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "add op input failed";
return nullptr;
}
return primitive_c;

return prim.release();
}

TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());


+ 24
- 72
mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc View File

@@ -44,89 +44,41 @@ ops::PrimitiveC *TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op,
}

if (tf_op.op() == "Add" || tf_op.op() == "AddV2") {
auto primitive_c = new (std::nothrow) ops::AddFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new AddFusion failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::AddFusion>();
return prim.release();
} else if (tf_op.op() == "Sub") {
auto primitive_c = new (std::nothrow) ops::SubFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new SubFusion failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::SubFusion>();
return prim.release();
} else if (tf_op.op() == "Mul") {
auto primitive_c = new (std::nothrow) ops::MulFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new MulFusion failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::MulFusion>();
return prim.release();
} else if (tf_op.op() == "Div" || tf_op.op() == "RealDiv") {
auto primitive_c = new (std::nothrow) ops::DivFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new DivFusion failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::DivFusion>();
return prim.release();
} else if (tf_op.op() == "Maximum") {
auto primitive_c = new (std::nothrow) ops::Maximum;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Maximum failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::Maximum>();
return prim.release();
} else if (tf_op.op() == "Minimum") {
auto primitive_c = new (std::nothrow) ops::Minimum;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Minimum failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::Minimum>();
return prim.release();
} else if (tf_op.op() == "Greater") {
auto primitive_c = new (std::nothrow) ops::Greater;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Greater failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::Greater>();
return prim.release();
} else if (tf_op.op() == "GreaterEqual") {
auto primitive_c = new (std::nothrow) ops::GreaterEqual;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new GreaterEqual failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::GreaterEqual>();
return prim.release();
} else if (tf_op.op() == "Less") {
auto primitive_c = new (std::nothrow) ops::Less;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Less failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::Less>();
return prim.release();
} else if (tf_op.op() == "LessEqual") {
auto primitive_c = new (std::nothrow) ops::LessEqual;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LessEqual failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::LessEqual>();
return prim.release();
} else if (tf_op.op() == "Equal") {
auto primitive_c = new (std::nothrow) ops::Equal;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Equal failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::Equal>();
return prim.release();
} else if (tf_op.op() == "NotEqual") {
auto primitive_c = new (std::nothrow) ops::NotEqual;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new NotEqual failed";
return nullptr;
}
return primitive_c;
auto prim = std::make_unique<ops::NotEqual>();
return prim.release();
}
return nullptr;
}


+ 3
- 7
mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc View File

@@ -26,18 +26,14 @@ namespace lite {
ops::PrimitiveC *TFAssertParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Assert;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "New Assert failed";
return nullptr;
}
auto prim = std::make_unique<ops::Assert>();

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "summarize", &attr_value)) {
MS_LOG(ERROR) << "The keep_dims attr should be specified";
return nullptr;
}
primitive_c->set_summarize((int64_t)(attr_value.i()));
prim->set_summarize((int64_t)(attr_value.i()));

*output_size = 0; // Assert not have output
for (int i = 0; i < tf_op.input_size(); ++i) {
@@ -47,7 +43,7 @@ ops::PrimitiveC *TFAssertParser::Parse(const tensorflow::NodeDef &tf_op,
}
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfAssertParser("Assert", new TFAssertParser());


+ 2
- 6
mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc View File

@@ -26,11 +26,7 @@ namespace lite {
ops::PrimitiveC *TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::BiasAdd;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new BiasAdd failed";
return nullptr;
}
auto prim = std::make_unique<ops::BiasAdd>();

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {
@@ -38,7 +34,7 @@ ops::PrimitiveC *TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/tf/tf_cast_parser.cc View File

@@ -26,18 +26,14 @@ namespace lite {
ops::PrimitiveC *TFCastParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Cast;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Cast failed";
return nullptr;
}
auto prim = std::make_unique<ops::Cast>();

auto dst_type = TensorFlowUtils::ParseAttrDataType(tf_op, "DstT");
if (dst_type == kTypeUnknown) {
MS_LOG(ERROR) << "Get attr DstT failed";
return nullptr;
}
primitive_c->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type)));
prim->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type)));

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
@@ -45,7 +41,7 @@ ops::PrimitiveC *TFCastParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfCastParser("Cast", new TFCastParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc View File

@@ -26,11 +26,7 @@ namespace lite {
ops::PrimitiveC *TFConcatParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Concat;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Concat failed";
return nullptr;
}
auto prim = std::make_unique<ops::Concat>();

auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(tf_op.input_size() - 1));
if (axis_node == nullptr) {
@@ -43,7 +39,7 @@ ops::PrimitiveC *TFConcatParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}
auto tensor_proto = attr_value.tensor();
primitive_c->set_axis(tensor_proto.int_val(0));
prim->set_axis(tensor_proto.int_val(0));

*output_size = 1;
for (int i = 0; i < tf_op.input_size() - 1; ++i) {
@@ -53,7 +49,7 @@ ops::PrimitiveC *TFConcatParser::Parse(const tensorflow::NodeDef &tf_op,
}
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfConcatV2Parser("ConcatV2", new TFConcatParser());


+ 12
- 16
mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc View File

@@ -27,14 +27,10 @@ namespace lite {
ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Conv2DFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Conv2DFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::Conv2DFusion>();

primitive_c->set_pad({0, 0, 0, 0});
primitive_c->set_group(1);
prim->set_pad({0, 0, 0, 0});
prim->set_group(1);

// parse format
auto format = TensorFlowUtils::ParseNodeFormat(tf_op);
@@ -42,7 +38,7 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "TF Conv2D with data_format=NCHW is not supported now";
return nullptr;
}
primitive_c->set_format(format);
prim->set_format(format);

// parse kernel
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1));
@@ -55,9 +51,9 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "parse kernels failed";
return nullptr;
}
primitive_c->set_kernel_size({kernels[0], kernels[1]});
primitive_c->set_out_channel(kernels[3]);
primitive_c->set_in_channel(kernels[2]);
prim->set_kernel_size({kernels[0], kernels[1]});
prim->set_out_channel(kernels[3]);
prim->set_in_channel(kernels[2]);

// parse stride
std::vector<int64_t> strides(2);
@@ -65,7 +61,7 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "parse strides failed";
return nullptr;
}
primitive_c->set_stride(strides);
prim->set_stride(strides);

// parse dilation
std::vector<int64_t> dilations(2);
@@ -73,11 +69,11 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "parse dilations failed";
return nullptr;
}
primitive_c->set_dilation(dilations);
prim->set_dilation(dilations);

// parse pad
auto padMode = ParsePadMode(tf_op);
primitive_c->set_pad_mode(padMode);
auto pad_mode = ParsePadMode(tf_op);
prim->set_pad_mode(pad_mode);

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {
@@ -85,7 +81,7 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfConvParser("Conv2D", new TFConvParser());


+ 2
- 6
mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc View File

@@ -26,11 +26,7 @@ namespace lite {
ops::PrimitiveC *TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::ExpandDims;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ExpandDims failed";
return nullptr;
}
auto prim = std::make_unique<ops::ExpandDims>();

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {
@@ -38,7 +34,7 @@ ops::PrimitiveC *TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfExpandDimsParser("ExpandDims", new TFExpandDimsParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc View File

@@ -26,11 +26,7 @@ namespace lite {
ops::PrimitiveC *TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Gather;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Gather failed";
return nullptr;
}
auto prim = std::make_unique<ops::Gather>();

int batchDims = 0;
tensorflow::AttrValue attr_value;
@@ -72,7 +68,7 @@ ops::PrimitiveC *TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
if (batchDims != 0 && !axis_is_set) {
axis = batchDims;
}
primitive_c->AddAttr("axis", MakeValue(axis));
prim->AddAttr("axis", MakeValue(axis));

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {
@@ -80,7 +76,7 @@ ops::PrimitiveC *TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfGatherV2Parser("Gather", new TFGatherParser());


+ 2
- 6
mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc View File

@@ -27,16 +27,12 @@ ops::PrimitiveC *TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
if (tf_op.op() == "LogicalAnd") {
auto primitive_c = new (std::nothrow) ops::LogicalAnd;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new LogicalAnd failed";
return nullptr;
}
auto prim = std::make_unique<ops::LogicalAnd>();
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return primitive_c;
return prim.release();
} else {
MS_LOG(ERROR) << "only LogicalAnd is supported now";
return nullptr;


+ 4
- 8
mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.cc View File

@@ -26,18 +26,14 @@ namespace lite {
ops::PrimitiveC *TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::MatMul;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new MatMul failed";
return nullptr;
}
auto prim = std::make_unique<ops::MatMul>();

tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_a", &attr_value)) {
primitive_c->set_transpose_a(attr_value.b());
prim->set_transpose_a(attr_value.b());
}
if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_b", &attr_value)) {
primitive_c->set_transpose_b(attr_value.b());
prim->set_transpose_b(attr_value.b());
}

*output_size = 1;
@@ -46,7 +42,7 @@ ops::PrimitiveC *TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfMatMulParser("MatMul", new TFMatMulParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/tf/tf_pack_parser.cc View File

@@ -26,18 +26,14 @@ namespace lite {
ops::PrimitiveC *TFPackParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Stack;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Stack failed";
return nullptr;
}
auto prim = std::make_unique<ops::Stack>();

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) {
MS_LOG(ERROR) << "The axis attr should be specified";
return nullptr;
}
primitive_c->set_axis({attr_value.i()});
prim->set_axis({attr_value.i()});

*output_size = 1;
for (int i = 0; i < tf_op.input_size(); ++i) {
@@ -47,7 +43,7 @@ ops::PrimitiveC *TFPackParser::Parse(const tensorflow::NodeDef &tf_op,
}
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfPackParser("Pack", new TFPackParser());


+ 5
- 9
mindspore/lite/tools/converter/parser/tf/tf_ragged_range_parser.cc View File

@@ -32,30 +32,26 @@ ops::PrimitiveC *TFRaggedRangeParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

auto primitive = new (std::nothrow) ops::Range;
if (primitive == nullptr) {
MS_LOG(ERROR) << "New RaggedRange failed";
return nullptr;
}
auto prim = std::make_unique<ops::Range>();

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "starts", &attr_value)) {
MS_LOG(ERROR) << "The starts attr should be specified";
return nullptr;
}
primitive->set_start(static_cast<int64_t>(attr_value.i()));
prim->set_start(static_cast<int64_t>(attr_value.i()));

if (!TensorFlowUtils::FindAttrValue(tf_op, "limits", &attr_value)) {
MS_LOG(ERROR) << "The limits attr should be specified";
return nullptr;
}
primitive->set_limit(static_cast<int64_t>(attr_value.i()));
prim->set_limit(static_cast<int64_t>(attr_value.i()));

if (!TensorFlowUtils::FindAttrValue(tf_op, "deltas", &attr_value)) {
MS_LOG(ERROR) << "The deltas attr should be specified";
return nullptr;
}
primitive->set_delta(static_cast<int64_t>(attr_value.i()));
prim->set_delta(static_cast<int64_t>(attr_value.i()));

*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
@@ -63,7 +59,7 @@ ops::PrimitiveC *TFRaggedRangeParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "add op input is failed!";
return nullptr;
}
return primitive;
return prim.release();
}

TFNodeRegistrar g_tfRaggedRangeParser("RaggedRange", new TFRaggedRangeParser());


+ 6
- 9
mindspore/lite/tools/converter/parser/tf/tf_range_parser.cc View File

@@ -33,23 +33,19 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

auto primitive_c = new (std::nothrow) ops::Range;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "New Range failed";
return nullptr;
}
auto prim = std::make_unique<ops::Range>();

tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(tf_op, "start", &attr_value)) {
primitive_c->set_start(static_cast<int64_t>(attr_value.i()));
prim->set_start(static_cast<int64_t>(attr_value.i()));
}

if (TensorFlowUtils::FindAttrValue(tf_op, "limit", &attr_value)) {
primitive_c->set_limit(static_cast<int64_t>(attr_value.i()));
prim->set_limit(static_cast<int64_t>(attr_value.i()));
}

if (TensorFlowUtils::FindAttrValue(tf_op, "delta", &attr_value)) {
primitive_c->set_delta(static_cast<int64_t>(attr_value.i()));
prim->set_delta(static_cast<int64_t>(attr_value.i()));
}

*output_size = 1;
@@ -60,7 +56,8 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "add op input failed!";
return nullptr;
}
return primitive_c;

return prim.release();
}

TFNodeRegistrar g_tfRangeParser("Range", new TFRangeParser());


+ 9
- 13
mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc View File

@@ -26,24 +26,20 @@ namespace lite {
ops::PrimitiveC *TFReduceParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::ReduceFusion;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new ReduceFusion failed";
return nullptr;
}
auto prim = std::make_unique<ops::ReduceFusion>();

if (tf_op.op() == "Sum") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum);
prim->set_mode(mindspore::ReduceMode::Reduce_Sum);
} else if (tf_op.op() == "Max") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Max);
prim->set_mode(mindspore::ReduceMode::Reduce_Max);
} else if (tf_op.op() == "Min") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Min);
prim->set_mode(mindspore::ReduceMode::Reduce_Min);
} else if (tf_op.op() == "Mean") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Mean);
prim->set_mode(mindspore::ReduceMode::Reduce_Mean);
} else if (tf_op.op() == "Prod") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_Prod);
prim->set_mode(mindspore::ReduceMode::Reduce_Prod);
} else if (tf_op.op() == "All") {
primitive_c->set_mode(mindspore::ReduceMode::Reduce_All);
prim->set_mode(mindspore::ReduceMode::Reduce_All);
} else {
MS_LOG(ERROR) << "unsupported reduce mode: " << tf_op.op();
return nullptr;
@@ -59,7 +55,7 @@ ops::PrimitiveC *TFReduceParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "the keep_dims attr of reduce should be bool type";
return nullptr;
}
primitive_c->set_keep_dims(attr_value.b());
prim->set_keep_dims(attr_value.b());

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {
@@ -67,7 +63,7 @@ ops::PrimitiveC *TFReduceParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfSumParser("Sum", new TFReduceParser());


+ 2
- 6
mindspore/lite/tools/converter/parser/tf/tf_reshape_parser.cc View File

@@ -26,11 +26,7 @@ namespace lite {
ops::PrimitiveC *TFReshapeParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Reshape;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Reshape failed";
return nullptr;
}
auto prim = std::make_unique<ops::Reshape>();

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {
@@ -38,7 +34,7 @@ ops::PrimitiveC *TFReshapeParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfReshapeParser("Reshape", new TFReshapeParser());


+ 4
- 8
mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc View File

@@ -32,23 +32,19 @@ ops::PrimitiveC *TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op
MS_LOG(ERROR) << "primitiveC is nullptr";
return nullptr;
}
auto primitive = new (std::nothrow) ops::ReverseSequence;
if (primitive == nullptr) {
MS_LOG(ERROR) << "New ReverseSequenceParser failed";
return nullptr;
}
auto prim = std::make_unique<ops::ReverseSequence>();

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "batch_dim", &attr_value)) {
MS_LOG(ERROR) << "The batch_dim attr should be specified";
return nullptr;
}
primitive->set_batch_dim(attr_value.i());
prim->set_batch_dim(attr_value.i());
if (!TensorFlowUtils::FindAttrValue(tf_op, "seq_dim", &attr_value)) {
MS_LOG(ERROR) << "The seq_dim attr should be specified";
return nullptr;
}
primitive->set_seq_dim(attr_value.i());
prim->set_seq_dim(attr_value.i());

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {
@@ -56,7 +52,7 @@ ops::PrimitiveC *TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op
return nullptr;
}

return primitive;
return prim.release();
}

TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser());


+ 2
- 6
mindspore/lite/tools/converter/parser/tf/tf_round_parser.cc View File

@@ -26,11 +26,7 @@ namespace lite {
ops::PrimitiveC *TFRoundParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Round;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Round failed";
return nullptr;
}
auto prim = std::make_unique<ops::Round>();

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
@@ -38,7 +34,7 @@ ops::PrimitiveC *TFRoundParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfRoundParser("Round", new TFRoundParser());


+ 2
- 6
mindspore/lite/tools/converter/parser/tf/tf_shape_parser.cc View File

@@ -26,11 +26,7 @@ namespace lite {
ops::PrimitiveC *TFShapeParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Shape;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Shape failed";
return nullptr;
}
auto prim = std::make_unique<ops::Shape>();

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
@@ -38,7 +34,7 @@ ops::PrimitiveC *TFShapeParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfShapeParser("Shape", new TFShapeParser());


+ 7
- 11
mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc View File

@@ -26,19 +26,15 @@ namespace lite {
ops::PrimitiveC *TFSplitParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Split;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Split failed";
return nullptr;
}
auto prim = std::make_unique<ops::Split>();

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "num_split", &attr_value)) {
MS_LOG(ERROR) << "The attribute num_split should be specified";
return nullptr;
}
auto numberSplit = attr_value.i();
primitive_c->set_output_num(numberSplit);
auto number_split = attr_value.i();
prim->set_output_num(number_split);

int split_dim_index = 2;
int input_index = 0;
@@ -57,7 +53,7 @@ ops::PrimitiveC *TFSplitParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}
auto splitDim = attr_value.tensor().int_val(0);
primitive_c->set_axis(splitDim);
prim->set_axis(splitDim);

if (tf_op.op() == "SplitV") {
auto size_splits_node = GetConstInputNode(tf_node_map, tf_op.input(1));
@@ -80,16 +76,16 @@ ops::PrimitiveC *TFSplitParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "memcpy_s failed";
return nullptr;
}
primitive_c->set_size_splits(sizeSplits);
prim->set_size_splits(sizeSplits);
}

*output_size = numberSplit;
*output_size = number_split;
if (AddOpInput(tf_op, input_index, inputs) != RET_OK) {
MS_LOG(ERROR) << "add op input failed";
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfSplitParser("Split", new TFSplitParser());


+ 3
- 7
mindspore/lite/tools/converter/parser/tf/tf_squeeze_parser.cc View File

@@ -27,11 +27,7 @@ namespace lite {
ops::PrimitiveC *TFSqueezeParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::Squeeze;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new Squeeze failed";
return nullptr;
}
auto prim = std::make_unique<ops::Squeeze>();

std::vector<int64_t> axis;
tensorflow::AttrValue attr_value;
@@ -43,7 +39,7 @@ ops::PrimitiveC *TFSqueezeParser::Parse(const tensorflow::NodeDef &tf_op,
for (int i = 0; i < dims.i_size(); ++i) {
axis.push_back(dims.i(i));
}
primitive_c->set_axis(axis);
prim->set_axis(axis);

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
@@ -51,7 +47,7 @@ ops::PrimitiveC *TFSqueezeParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfSqueezeParser("Squeeze", new TFSqueezeParser());


+ 7
- 11
mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc View File

@@ -27,42 +27,38 @@ namespace lite {
ops::PrimitiveC *TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive_c = new (std::nothrow) ops::StridedSlice;
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "new StridedSlice failed";
return nullptr;
}
auto prim = std::make_unique<ops::StridedSlice>();

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "begin_mask", &attr_value)) {
MS_LOG(ERROR) << "The begin_mask attr should be specified";
return nullptr;
}
primitive_c->set_begin_mask(attr_value.i());
prim->set_begin_mask(attr_value.i());

if (!TensorFlowUtils::FindAttrValue(tf_op, "end_mask", &attr_value)) {
MS_LOG(ERROR) << "The end_mask attr should be specified";
return nullptr;
}
primitive_c->set_end_mask(attr_value.i());
prim->set_end_mask(attr_value.i());

if (!TensorFlowUtils::FindAttrValue(tf_op, "ellipsis_mask", &attr_value)) {
MS_LOG(ERROR) << "The ellipsis_mask attr should be specified";
return nullptr;
}
primitive_c->set_ellipsis_mask(attr_value.i());
prim->set_ellipsis_mask(attr_value.i());

if (!TensorFlowUtils::FindAttrValue(tf_op, "new_axis_mask", &attr_value)) {
MS_LOG(ERROR) << "The new_axis_mask attr should be specified";
return nullptr;
}
primitive_c->set_new_axis_mask(attr_value.i());
prim->set_new_axis_mask(attr_value.i());

if (!TensorFlowUtils::FindAttrValue(tf_op, "shrink_axis_mask", &attr_value)) {
MS_LOG(ERROR) << "The shrink_axis_mask attr should be specified";
return nullptr;
}
primitive_c->set_shrink_axis_mask(attr_value.i());
prim->set_shrink_axis_mask(attr_value.i());

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK ||
@@ -71,7 +67,7 @@ ops::PrimitiveC *TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
return nullptr;
}

return primitive_c;
return prim.release();
}

TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser());


+ 4
- 8
mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.cc View File

@@ -26,11 +26,7 @@ namespace lite {
ops::PrimitiveC *TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto primitive = new (std::nothrow) ops::TensorListFromTensor;
if (primitive == nullptr) {
MS_LOG(ERROR) << "New TensorListFromTensor failed";
return nullptr;
}
auto prim = std::make_unique<ops::TensorListFromTensor>();

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) {
@@ -42,7 +38,7 @@ ops::PrimitiveC *TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef &
MS_LOG(ERROR) << "tensor_list_from_tensor element dtype must be known type";
return nullptr;
}
primitive->set_element_dtype((int64_t)(type));
prim->set_element_dtype((int64_t)(type));

if (!TensorFlowUtils::FindAttrValue(tf_op, "shape_type", &attr_value)) {
MS_LOG(ERROR) << "The shape_type attr should be specified";
@@ -53,7 +49,7 @@ ops::PrimitiveC *TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef &
MS_LOG(ERROR) << "tensor_list_from_tensor shape type must be known type";
return nullptr;
}
primitive->set_shape_type((int64_t)(type));
prim->set_shape_type((int64_t)(type));

*output_size = 1;
for (int i = 0; i < 2; ++i) {
@@ -63,7 +59,7 @@ ops::PrimitiveC *TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef &
}
}

return primitive;
return prim.release();
}

TFNodeRegistrar g_tfTensorListFromTensorParser("TensorListFromTensor", new TFTensorListFromTensorParser());


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save