Browse Source

modify static check

pull/15967/head
liuyu 4 years ago
parent
commit
ddfc7f3c02
100 changed files with 1430 additions and 1286 deletions
  1. +13
    -11
      mindspore/lite/src/ops/populate/activation_grad_populate.cc
  2. +17
    -17
      mindspore/lite/src/ops/populate/activation_populate.cc
  3. +4
    -2
      mindspore/lite/src/ops/populate/adam_populate.cc
  4. +10
    -9
      mindspore/lite/src/ops/populate/add_populate.cc
  5. +35
    -32
      mindspore/lite/src/ops/populate/adder_populate.cc
  6. +10
    -9
      mindspore/lite/src/ops/populate/addn_populate.cc
  7. +1
    -2
      mindspore/lite/src/ops/populate/argmax_populate.cc
  8. +16
    -16
      mindspore/lite/src/ops/populate/argmin_populate.cc
  9. +4
    -2
      mindspore/lite/src/ops/populate/arithmetic_populate.cc
  10. +9
    -7
      mindspore/lite/src/ops/populate/arithmetic_self_populate.cc
  11. +9
    -8
      mindspore/lite/src/ops/populate/assert_populate.cc
  12. +4
    -2
      mindspore/lite/src/ops/populate/assign_add_populate.cc
  13. +3
    -2
      mindspore/lite/src/ops/populate/assign_populate.cc
  14. +14
    -12
      mindspore/lite/src/ops/populate/audio_spectrogram_populate.cc
  15. +15
    -15
      mindspore/lite/src/ops/populate/batch_norm_populate.cc
  16. +25
    -21
      mindspore/lite/src/ops/populate/batch_to_space_populate.cc
  17. +9
    -9
      mindspore/lite/src/ops/populate/bias_add_populate.cc
  18. +8
    -9
      mindspore/lite/src/ops/populate/bias_grad_populate.cc
  19. +13
    -14
      mindspore/lite/src/ops/populate/binary_cross_entropy_grad_populate.cc
  20. +11
    -10
      mindspore/lite/src/ops/populate/binary_cross_entropy_populate.cc
  21. +14
    -11
      mindspore/lite/src/ops/populate/broadcast_to_populate.cc
  22. +10
    -6
      mindspore/lite/src/ops/populate/call_populate.cc
  23. +4
    -4
      mindspore/lite/src/ops/populate/cast_populate.cc
  24. +9
    -9
      mindspore/lite/src/ops/populate/clip_populate.cc
  25. +9
    -9
      mindspore/lite/src/ops/populate/common_populate.cc
  26. +13
    -13
      mindspore/lite/src/ops/populate/concat_populate.cc
  27. +22
    -18
      mindspore/lite/src/ops/populate/constant_of_shape_populate.cc
  28. +44
    -42
      mindspore/lite/src/ops/populate/conv2d_populate.cc
  29. +15
    -14
      mindspore/lite/src/ops/populate/crop_and_resize_populate.cc
  30. +21
    -18
      mindspore/lite/src/ops/populate/crop_populate.cc
  31. +15
    -10
      mindspore/lite/src/ops/populate/cumsum_populate.cc
  32. +6
    -5
      mindspore/lite/src/ops/populate/custom_extract_features_populate.cc
  33. +6
    -3
      mindspore/lite/src/ops/populate/custom_normalize_populate.cc
  34. +9
    -7
      mindspore/lite/src/ops/populate/custom_predict_populate.cc
  35. +46
    -43
      mindspore/lite/src/ops/populate/deconv2d_populate.cc
  36. +4
    -2
      mindspore/lite/src/ops/populate/default_populate.cc
  37. +13
    -11
      mindspore/lite/src/ops/populate/depth_to_space_populate.cc
  38. +5
    -4
      mindspore/lite/src/ops/populate/depthwise_conv2d_populate.cc
  39. +26
    -26
      mindspore/lite/src/ops/populate/detection_post_process_populate.cc
  40. +2
    -1
      mindspore/lite/src/ops/populate/div_populate.cc
  41. +10
    -10
      mindspore/lite/src/ops/populate/eltwise_populate.cc
  42. +14
    -13
      mindspore/lite/src/ops/populate/elu_populate.cc
  43. +8
    -8
      mindspore/lite/src/ops/populate/embedding_lookup_populate.cc
  44. +16
    -15
      mindspore/lite/src/ops/populate/exp_populate.cc
  45. +9
    -9
      mindspore/lite/src/ops/populate/expand_dims_populate.cc
  46. +9
    -9
      mindspore/lite/src/ops/populate/fill_populate.cc
  47. +8
    -7
      mindspore/lite/src/ops/populate/flatten_populate.cc
  48. +23
    -23
      mindspore/lite/src/ops/populate/full_connection_populate.cc
  49. +13
    -11
      mindspore/lite/src/ops/populate/fused_batchnorm_populate.cc
  50. +9
    -9
      mindspore/lite/src/ops/populate/gather_nd_populate.cc
  51. +8
    -9
      mindspore/lite/src/ops/populate/gather_populate.cc
  52. +13
    -14
      mindspore/lite/src/ops/populate/gru_populate.cc
  53. +5
    -2
      mindspore/lite/src/ops/populate/hashtable_lookup_populate.cc
  54. +11
    -10
      mindspore/lite/src/ops/populate/instance_norm_populate.cc
  55. +17
    -17
      mindspore/lite/src/ops/populate/l2_norm_populate.cc
  56. +14
    -12
      mindspore/lite/src/ops/populate/layer_norm_grad_populate.cc
  57. +17
    -14
      mindspore/lite/src/ops/populate/layer_norm_populate.cc
  58. +14
    -12
      mindspore/lite/src/ops/populate/local_response_normalization_populate.cc
  59. +14
    -14
      mindspore/lite/src/ops/populate/log_softmax_populate.cc
  60. +12
    -11
      mindspore/lite/src/ops/populate/lsh_projection_populate.cc
  61. +15
    -16
      mindspore/lite/src/ops/populate/lstm_populate.cc
  62. +16
    -12
      mindspore/lite/src/ops/populate/matmul_populate.cc
  63. +10
    -8
      mindspore/lite/src/ops/populate/merge_populate.cc
  64. +13
    -13
      mindspore/lite/src/ops/populate/mfcc_populate.cc
  65. +4
    -4
      mindspore/lite/src/ops/populate/mul_populate.cc
  66. +8
    -6
      mindspore/lite/src/ops/populate/non_max_suppression_populate.cc
  67. +12
    -10
      mindspore/lite/src/ops/populate/one_hot_populate.cc
  68. +5
    -3
      mindspore/lite/src/ops/populate/oneslike_populate.cc
  69. +9
    -6
      mindspore/lite/src/ops/populate/p_relu_populate.cc
  70. +13
    -10
      mindspore/lite/src/ops/populate/pad_populate.cc
  71. +11
    -10
      mindspore/lite/src/ops/populate/partial_populate.cc
  72. +83
    -78
      mindspore/lite/src/ops/populate/pooling_populate.cc
  73. +15
    -15
      mindspore/lite/src/ops/populate/power_populate.cc
  74. +33
    -27
      mindspore/lite/src/ops/populate/prior_box_populate.cc
  75. +13
    -11
      mindspore/lite/src/ops/populate/quant_dtype_cast_populate.cc
  76. +15
    -12
      mindspore/lite/src/ops/populate/random_standard_normal_populate.cc
  77. +18
    -17
      mindspore/lite/src/ops/populate/range_populate.cc
  78. +9
    -9
      mindspore/lite/src/ops/populate/rank_populate.cc
  79. +15
    -14
      mindspore/lite/src/ops/populate/reduce_populate.cc
  80. +9
    -9
      mindspore/lite/src/ops/populate/reshape_populate.cc
  81. +15
    -15
      mindspore/lite/src/ops/populate/resize_populate.cc
  82. +13
    -11
      mindspore/lite/src/ops/populate/reverse_populate.cc
  83. +14
    -15
      mindspore/lite/src/ops/populate/reverse_sequence_populate.cc
  84. +15
    -16
      mindspore/lite/src/ops/populate/roi_pooling_populate.cc
  85. +12
    -12
      mindspore/lite/src/ops/populate/scale_populate.cc
  86. +10
    -9
      mindspore/lite/src/ops/populate/scatter_nd_populate.cc
  87. +9
    -9
      mindspore/lite/src/ops/populate/shape_populate.cc
  88. +14
    -11
      mindspore/lite/src/ops/populate/skip_gram_populate.cc
  89. +13
    -9
      mindspore/lite/src/ops/populate/slice_populate.cc
  90. +17
    -16
      mindspore/lite/src/ops/populate/softmax_populate.cc
  91. +27
    -24
      mindspore/lite/src/ops/populate/space_to_batch_nd_populate.cc
  92. +26
    -21
      mindspore/lite/src/ops/populate/space_to_batch_populate.cc
  93. +13
    -10
      mindspore/lite/src/ops/populate/space_to_depth_populate.cc
  94. +10
    -8
      mindspore/lite/src/ops/populate/sparse_softmax_cross_entropy_with_logits.cc
  95. +9
    -9
      mindspore/lite/src/ops/populate/sparse_to_dense_populate.cc
  96. +36
    -32
      mindspore/lite/src/ops/populate/splice_populate.cc
  97. +24
    -24
      mindspore/lite/src/ops/populate/split_populate.cc
  98. +29
    -20
      mindspore/lite/src/ops/populate/split_with_overlap_populate.cc
  99. +19
    -20
      mindspore/lite/src/ops/populate/squeeze_populate.cc
  100. +12
    -11
      mindspore/lite/src/ops/populate/stack_populate.cc

+ 13
- 11
mindspore/lite/src/ops/populate/activation_grad_populate.cc View File

@@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_ActivationGrad;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateActivationGradParameter(const void *prim) { OpParameter *PopulateActivationGradParameter(const void *prim) {
auto *act_param = reinterpret_cast<ActivationGradParameter *>(malloc(sizeof(ActivationGradParameter)));
if (act_param == nullptr) {
MS_LOG(ERROR) << "malloc ActivationParameter failed.";
return nullptr;
}
memset(act_param, 0, sizeof(ActivationGradParameter));

auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_ActivationGrad(); auto value = primitive->value_as_ActivationGrad();
@@ -34,11 +27,20 @@ OpParameter *PopulateActivationGradParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
act_param->op_parameter.type_ = primitive->value_type();
act_param->type_ = static_cast<int>(value->activation_type());
act_param->alpha_ = value->alpha();
return reinterpret_cast<OpParameter *>(act_param);

auto *param = reinterpret_cast<ActivationGradParameter *>(malloc(sizeof(ActivationGradParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ActivationParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ActivationGradParameter));

param->op_parameter.type_ = primitive->value_type();
param->type_ = static_cast<int>(value->activation_type());
param->alpha_ = value->alpha();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_ActivationGrad, PopulateActivationGradParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_ActivationGrad, PopulateActivationGradParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 17
- 17
mindspore/lite/src/ops/populate/activation_populate.cc View File

@@ -19,29 +19,29 @@ using mindspore::schema::PrimitiveType_Activation;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateRelu6Parameter(const void *prim) { OpParameter *PopulateRelu6Parameter(const void *prim) {
auto *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
if (act_param == nullptr) {
MS_LOG(ERROR) << "malloc ActivationParameter failed.";
return nullptr;
}
memset(act_param, 0, sizeof(ActivationParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
act_param->op_parameter_.type_ = primitive->value_type();
auto acti_prim = primitive->value_as_Activation();
if (acti_prim == nullptr) {
MS_LOG(ERROR) << "acti_prim is nullptr";
auto value = primitive->value_as_Activation();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ActivationParameter failed.";
return nullptr; return nullptr;
} }
act_param->type_ = static_cast<int>(acti_prim->activation_type());
act_param->alpha_ = acti_prim->alpha();
act_param->min_val_ = acti_prim->min_val();
act_param->max_val_ = acti_prim->max_val();
return reinterpret_cast<OpParameter *>(act_param);
memset(param, 0, sizeof(ActivationParameter));

param->op_parameter_.type_ = primitive->value_type();
param->type_ = static_cast<int>(value->activation_type());
param->alpha_ = value->alpha();
param->min_val_ = value->min_val();
param->max_val_ = value->max_val();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Activation, PopulateRelu6Parameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Activation, PopulateRelu6Parameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 4
- 2
mindspore/lite/src/ops/populate/adam_populate.cc View File

@@ -19,14 +19,16 @@ using mindspore::schema::PrimitiveType_Adam;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateAdamParameter(const void *prim) { OpParameter *PopulateAdamParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "malloc Adam Parameter failed."; MS_LOG(ERROR) << "malloc Adam Parameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(OpParameter)); memset(param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->type_ = primitive->value_type(); param->type_ = primitive->value_type();
return param; return param;
} }


+ 10
- 9
mindspore/lite/src/ops/populate/add_populate.cc View File

@@ -20,24 +20,25 @@ using mindspore::schema::PrimitiveType_AddFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateAddParameter(const void *prim) { OpParameter *PopulateAddParameter(const void *prim) {
auto *primitive = static_cast<const schema::Primitive *>(prim);
auto value = primitive->value_as_AddFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

ArithmeticParameter *param = PopulateArithmeticCommonPara(prim); ArithmeticParameter *param = PopulateArithmeticCommonPara(prim);
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr; return nullptr;
} }
auto *primitive = static_cast<const schema::Primitive *>(prim);
param->op_parameter_.type_ = primitive->value_type(); param->op_parameter_.type_ = primitive->value_type();
auto add_prim = primitive->value_as_AddFusion();
if (add_prim == nullptr) {
MS_LOG(ERROR) << "add_prim is nullptr";
return nullptr;
}
param->activation_type_ = add_prim->activation_type();
param->activation_type_ = value->activation_type();
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }
} // namespace

REG_POPULATE(PrimitiveType_AddFusion, PopulateAddParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_AddFusion, PopulateAddParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 35
- 32
mindspore/lite/src/ops/populate/adder_populate.cc View File

@@ -21,56 +21,59 @@ using mindspore::schema::PrimitiveType_AdderFusion;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateAdderParameter(const void *prim) { OpParameter *PopulateAdderParameter(const void *prim) {
ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (conv_param == nullptr) {
MS_LOG(ERROR) << "malloc ConvParameter failed.";
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_AdderFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
memset(conv_param, 0, sizeof(ConvParameter));


auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
conv_param->op_parameter_.type_ = primitive->value_type();
auto conv_primitive = primitive->value_as_AdderFusion();
if (conv_primitive == nullptr) {
MS_LOG(ERROR) << "conv_primitive is nullptr";
auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ConvParameter failed.";
return nullptr; return nullptr;
} }
auto kernel_size = conv_primitive->kernel_size();
auto stride = conv_primitive->stride();
auto pad_list = conv_primitive->pad_list();
auto dilation = conv_primitive->dilation();
memset(param, 0, sizeof(ConvParameter));

param->op_parameter_.type_ = primitive->value_type();
auto kernel_size = value->kernel_size();
auto stride = value->stride();
auto pad_list = value->pad_list();
auto dilation = value->dilation();
if (kernel_size == nullptr || stride == nullptr || pad_list == nullptr || dilation == nullptr) { if (kernel_size == nullptr || stride == nullptr || pad_list == nullptr || dilation == nullptr) {
MS_LOG(ERROR) << "nullptr"; MS_LOG(ERROR) << "nullptr";
free(param);
return nullptr; return nullptr;
} }
conv_param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
conv_param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
conv_param->group_ = static_cast<int>(conv_primitive->group());
conv_param->stride_h_ = static_cast<int>(*(stride->begin()));
conv_param->stride_w_ = static_cast<int>(*(stride->begin() + 1));
conv_param->pad_u_ = static_cast<int>(*(pad_list->begin()));
conv_param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1));
conv_param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2));
conv_param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3));
conv_param->dilation_h_ = static_cast<int>(*(dilation->begin()));
conv_param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1));
conv_param->input_channel_ = static_cast<int>(conv_primitive->in_channel());
conv_param->output_channel_ = static_cast<int>(conv_primitive->out_channel());
auto act_type = conv_primitive->activation_type();
param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
param->group_ = static_cast<int>(value->group());
param->stride_h_ = static_cast<int>(*(stride->begin()));
param->stride_w_ = static_cast<int>(*(stride->begin() + 1));
param->pad_u_ = static_cast<int>(*(pad_list->begin()));
param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1));
param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2));
param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3));
param->dilation_h_ = static_cast<int>(*(dilation->begin()));
param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1));
param->input_channel_ = static_cast<int>(value->in_channel());
param->output_channel_ = static_cast<int>(value->out_channel());
auto act_type = value->activation_type();
switch (act_type) { switch (act_type) {
case schema::ActivationType_RELU: case schema::ActivationType_RELU:
conv_param->act_type_ = ActType_Relu;
param->act_type_ = ActType_Relu;
break; break;
case schema::ActivationType_RELU6: case schema::ActivationType_RELU6:
conv_param->act_type_ = ActType_Relu6;
param->act_type_ = ActType_Relu6;
break; break;
default: default:
conv_param->act_type_ = ActType_No;
param->act_type_ = ActType_No;
break; break;
} }
return reinterpret_cast<OpParameter *>(conv_param);
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_AdderFusion, PopulateAdderParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_AdderFusion, PopulateAdderParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 10
- 9
mindspore/lite/src/ops/populate/addn_populate.cc View File

@@ -19,20 +19,21 @@ using mindspore::schema::PrimitiveType_AddN;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateAddNParameter(const void *prim) { OpParameter *PopulateAddNParameter(const void *prim) {
auto *addn_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (addn_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter failed."; MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr; return nullptr;
} }
memset(addn_param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
addn_param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(addn_param);
memset(param, 0, sizeof(OpParameter));

param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_AddN, PopulateAddNParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_AddN, PopulateAddNParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 1
- 2
mindspore/lite/src/ops/populate/argmax_populate.cc View File

@@ -19,7 +19,6 @@ using mindspore::schema::PrimitiveType_ArgMaxFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateArgMaxParameter(const void *prim) { OpParameter *PopulateArgMaxParameter(const void *prim) {
auto *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter))); auto *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter)));
if (arg_param == nullptr) { if (arg_param == nullptr) {
@@ -32,6 +31,7 @@ OpParameter *PopulateArgMaxParameter(const void *prim) {
auto param = primitive->value_as_ArgMaxFusion(); auto param = primitive->value_as_ArgMaxFusion();
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr"; MS_LOG(ERROR) << "param is nullptr";
free(arg_param);
return nullptr; return nullptr;
} }
arg_param->axis_ = param->axis(); arg_param->axis_ = param->axis();
@@ -41,7 +41,6 @@ OpParameter *PopulateArgMaxParameter(const void *prim) {
arg_param->get_max_ = true; arg_param->get_max_ = true;
return reinterpret_cast<OpParameter *>(arg_param); return reinterpret_cast<OpParameter *>(arg_param);
} }
} // namespace


REG_POPULATE(PrimitiveType_ArgMaxFusion, PopulateArgMaxParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ArgMaxFusion, PopulateArgMaxParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 16
- 16
mindspore/lite/src/ops/populate/argmin_populate.cc View File

@@ -19,29 +19,29 @@ using mindspore::schema::PrimitiveType_ArgMinFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateArgMinParameter(const void *prim) { OpParameter *PopulateArgMinParameter(const void *prim) {
ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter)));
if (arg_param == nullptr) {
MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed.";
auto *primitive = static_cast<const schema::Primitive *>(prim);
auto value = primitive->value_as_ArgMinFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
memset(arg_param, 0, sizeof(ArgMinMaxParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim);
arg_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_ArgMinFusion();

auto *param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed.";
return nullptr; return nullptr;
} }
arg_param->axis_ = param->axis();
arg_param->topk_ = param->top_k();
arg_param->out_value_ = param->out_max_value();
arg_param->keep_dims_ = param->keep_dims();
arg_param->get_max_ = false;
return reinterpret_cast<OpParameter *>(arg_param);
memset(param, 0, sizeof(ArgMinMaxParameter));

param->op_parameter_.type_ = primitive->value_type();
param->axis_ = value->axis();
param->topk_ = value->top_k();
param->out_value_ = value->out_max_value();
param->keep_dims_ = value->keep_dims();
param->get_max_ = false;
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_ArgMinFusion, PopulateArgMinParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ArgMinFusion, PopulateArgMinParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 4
- 2
mindspore/lite/src/ops/populate/arithmetic_populate.cc View File

@@ -36,14 +36,16 @@ using mindspore::schema::PrimitiveType_SquaredDifference;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
ArithmeticParameter *PopulateArithmeticCommonPara(const void *prim) { ArithmeticParameter *PopulateArithmeticCommonPara(const void *prim) {
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(ArithmeticParameter)); memset(param, 0, sizeof(ArithmeticParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->op_parameter_.type_ = primitive->value_type(); param->op_parameter_.type_ = primitive->value_type();
param->broadcasting_ = false; param->broadcasting_ = false;
param->ndim_ = 0; param->ndim_ = 0;


+ 9
- 7
mindspore/lite/src/ops/populate/arithmetic_self_populate.cc View File

@@ -35,16 +35,18 @@ using mindspore::schema::PrimitiveType_Square;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateArithmeticSelf(const void *prim) { OpParameter *PopulateArithmeticSelf(const void *prim) {
auto *arithmetic_self_param = reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
if (arithmetic_self_param == nullptr) {
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed."; MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed.";
return nullptr; return nullptr;
} }
memset(arithmetic_self_param, 0, sizeof(ArithmeticSelfParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
arithmetic_self_param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(arithmetic_self_param);
memset(param, 0, sizeof(ArithmeticSelfParameter));

param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_Abs, PopulateArithmeticSelf, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Abs, PopulateArithmeticSelf, SCHEMA_CUR)


+ 9
- 8
mindspore/lite/src/ops/populate/assert_populate.cc View File

@@ -18,20 +18,21 @@ using mindspore::schema::PrimitiveType_Assert;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

OpParameter *PopulateAssertParameter(const void *prim) { OpParameter *PopulateAssertParameter(const void *prim) {
auto *assert_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (assert_parameter == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc AssertParameter failed."; MS_LOG(ERROR) << "malloc AssertParameter failed.";
return nullptr; return nullptr;
} }
memset(assert_parameter, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
assert_parameter->type_ = primitive->value_type();
memset(param, 0, sizeof(OpParameter));


return reinterpret_cast<OpParameter *>(assert_parameter);
param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_Assert, PopulateAssertParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Assert, PopulateAssertParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 4
- 2
mindspore/lite/src/ops/populate/assign_add_populate.cc View File

@@ -19,14 +19,16 @@ using mindspore::schema::PrimitiveType_AssignAdd;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateAssignAddParameter(const void *prim) { OpParameter *PopulateAssignAddParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "malloc AssignAdd Parameter failed."; MS_LOG(ERROR) << "malloc AssignAdd Parameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(OpParameter)); memset(param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->type_ = primitive->value_type(); param->type_ = primitive->value_type();
return param; return param;
} }


+ 3
- 2
mindspore/lite/src/ops/populate/assign_populate.cc View File

@@ -19,6 +19,9 @@ using mindspore::schema::PrimitiveType_Assign;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateAssignParameter(const void *prim) { OpParameter *PopulateAssignParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "malloc Assign Parameter failed."; MS_LOG(ERROR) << "malloc Assign Parameter failed.";
@@ -26,8 +29,6 @@ OpParameter *PopulateAssignParameter(const void *prim) {
} }
memset(param, 0, sizeof(OpParameter)); memset(param, 0, sizeof(OpParameter));


auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
param->type_ = primitive->value_type(); param->type_ = primitive->value_type();
return param; return param;
} }


+ 14
- 12
mindspore/lite/src/ops/populate/audio_spectrogram_populate.cc View File

@@ -21,25 +21,27 @@ namespace mindspore {
namespace lite { namespace lite {
namespace { namespace {
OpParameter *PopulateAudioSpectrogramParameter(const void *prim) { OpParameter *PopulateAudioSpectrogramParameter(const void *prim) {
auto *arg_param = reinterpret_cast<AudioSpectrogramParameter *>(malloc(sizeof(AudioSpectrogramParameter)));
if (arg_param == nullptr) {
MS_LOG(ERROR) << "malloc AudioSpectrogramParameter failed.";
auto *primitive = static_cast<const schema::Primitive *>(prim);
auto value = primitive->value_as_AudioSpectrogram();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
memset(arg_param, 0, sizeof(AudioSpectrogramParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim);
arg_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_AudioSpectrogram();

auto *param = reinterpret_cast<AudioSpectrogramParameter *>(malloc(sizeof(AudioSpectrogramParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc AudioSpectrogramParameter failed.";
return nullptr; return nullptr;
} }
arg_param->window_size_ = param->window_size();
arg_param->stride_ = param->stride();
return reinterpret_cast<OpParameter *>(arg_param);
memset(param, 0, sizeof(AudioSpectrogramParameter));

param->op_parameter_.type_ = primitive->value_type();
param->window_size_ = value->window_size();
param->stride_ = value->stride();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace } // namespace


REG_POPULATE(PrimitiveType_AudioSpectrogram, PopulateAudioSpectrogramParameter, SCHEMA_CUR);
REG_POPULATE(PrimitiveType_AudioSpectrogram, PopulateAudioSpectrogramParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 15
- 15
mindspore/lite/src/ops/populate/batch_norm_populate.cc View File

@@ -19,27 +19,27 @@ using mindspore::schema::PrimitiveType_BatchNorm;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateBatchNorm(const void *prim) { OpParameter *PopulateBatchNorm(const void *prim) {
auto *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
if (batch_norm_param == nullptr) {
MS_LOG(ERROR) << "malloc BatchNormParameter failed.";
return nullptr;
}
memset(batch_norm_param, 0, sizeof(BatchNormParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
batch_norm_param->op_parameter_.type_ = primitive->value_type();
auto prim_batchnorm = primitive->value_as_BatchNorm();
if (prim_batchnorm == nullptr) {
MS_LOG(ERROR) << "prim_batchnorm is nullptr";
auto value = primitive->value_as_BatchNorm();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc BatchNormParameter failed.";
return nullptr; return nullptr;
} }
batch_norm_param->epsilon_ = prim_batchnorm->epsilon();
batch_norm_param->fused_ = false;
return reinterpret_cast<OpParameter *>(batch_norm_param);
memset(param, 0, sizeof(BatchNormParameter));

param->op_parameter_.type_ = primitive->value_type();
param->epsilon_ = value->epsilon();
param->fused_ = false;
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_BatchNorm, PopulateBatchNorm, SCHEMA_CUR) REG_POPULATE(PrimitiveType_BatchNorm, PopulateBatchNorm, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 25
- 21
mindspore/lite/src/ops/populate/batch_to_space_populate.cc View File

@@ -20,48 +20,52 @@ using mindspore::schema::PrimitiveType_BatchToSpaceND;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateBatchToSpaceParameter(const void *prim) { OpParameter *PopulateBatchToSpaceParameter(const void *prim) {
auto *batch_space_param = reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter)));
if (batch_space_param == nullptr) {
MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed.";
return nullptr;
}
memset(batch_space_param, 0, sizeof(BatchToSpaceParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
batch_space_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_BatchToSpace();
auto value = primitive->value_as_BatchToSpace();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed.";
return nullptr; return nullptr;
} }
auto block_size = param->block_size();
memset(param, 0, sizeof(BatchToSpaceParameter));

param->op_parameter_.type_ = primitive->value_type();
auto block_size = value->block_size();
if (block_size == nullptr) { if (block_size == nullptr) {
return reinterpret_cast<OpParameter *>(batch_space_param);
return reinterpret_cast<OpParameter *>(param);
} }
auto block_shape = std::vector<int64_t>(block_size->begin(), block_size->end()); auto block_shape = std::vector<int64_t>(block_size->begin(), block_size->end());
if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE;
free(batch_space_param);
free(param);
return nullptr; return nullptr;
} }


auto crop = param->crops();
auto crop = value->crops();
if (crop == nullptr) { if (crop == nullptr) {
MS_LOG(ERROR) << "crop is nullptr"; MS_LOG(ERROR) << "crop is nullptr";
free(param);
return nullptr; return nullptr;
} }
auto fb_crops = crop->data(); auto fb_crops = crop->data();
if (fb_crops == nullptr) { if (fb_crops == nullptr) {
MS_LOG(ERROR) << "fb_crops is nullptr"; MS_LOG(ERROR) << "fb_crops is nullptr";
free(param);
return nullptr; return nullptr;
} }
std::vector<int64_t> crops; std::vector<int64_t> crops;
for (auto iter = fb_crops->begin(); iter != fb_crops->end(); ++iter) {
auto crops_data = (*iter)->data();
for (auto fb_crop : *fb_crops) {
auto crops_data = fb_crop->data();
if (crops_data == nullptr) { if (crops_data == nullptr) {
MS_LOG(ERROR) << "crops_data is nullptr"; MS_LOG(ERROR) << "crops_data is nullptr";
free(param);
return nullptr; return nullptr;
} }
auto crops_vec = std::vector<int64_t>(crops_data->begin(), crops_data->end()); auto crops_vec = std::vector<int64_t>(crops_data->begin(), crops_data->end());
@@ -69,20 +73,20 @@ OpParameter *PopulateBatchToSpaceParameter(const void *prim) {
} }
if (crops.size() != COMM_SHAPE_SIZE) { if (crops.size() != COMM_SHAPE_SIZE) {
MS_LOG(ERROR) << "batch_to_space crops size should be " << COMM_SHAPE_SIZE; MS_LOG(ERROR) << "batch_to_space crops size should be " << COMM_SHAPE_SIZE;
free(batch_space_param);
free(param);
return nullptr; return nullptr;
} }


for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
batch_space_param->block_shape_[i] = static_cast<int>(block_shape[i]);
param->block_shape_[i] = static_cast<int>(block_shape[i]);
} }


for (int i = 0; i < COMM_SHAPE_SIZE; ++i) { for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
batch_space_param->crops_[i] = static_cast<int>(crops[i]);
param->crops_[i] = static_cast<int>(crops[i]);
} }
return reinterpret_cast<OpParameter *>(batch_space_param);
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 9
- 9
mindspore/lite/src/ops/populate/bias_add_populate.cc View File

@@ -19,21 +19,21 @@ using mindspore::schema::PrimitiveType_BiasAdd;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateBiasAddParameter(const void *prim) { OpParameter *PopulateBiasAddParameter(const void *prim) {
auto *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr; return nullptr;
} }
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
arithmetic_param->op_parameter_.type_ = primitive->value_type();
memset(param, 0, sizeof(ArithmeticParameter));


return reinterpret_cast<OpParameter *>(arithmetic_param);
param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_BiasAdd, PopulateBiasAddParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_BiasAdd, PopulateBiasAddParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 8
- 9
mindspore/lite/src/ops/populate/bias_grad_populate.cc View File

@@ -19,21 +19,20 @@ using mindspore::schema::PrimitiveType_BiasAddGrad;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateBiasAddGradParameter(const void *prim) { OpParameter *PopulateBiasAddGradParameter(const void *prim) {
auto *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr; return nullptr;
} }
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
memset(param, 0, sizeof(ArithmeticParameter));


auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
arithmetic_param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(arithmetic_param);
param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_BiasAddGrad, PopulateBiasAddGradParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_BiasAddGrad, PopulateBiasAddGradParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite


+ 13
- 14
mindspore/lite/src/ops/populate/binary_cross_entropy_grad_populate.cc View File

@@ -19,27 +19,26 @@ using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) { OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) {
auto *bce_param =
reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
if (bce_param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
return nullptr;
}
memset(bce_param, 0, sizeof(BinaryCrossEntropyGradParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
bce_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_BinaryCrossEntropyGrad();
if (param == nullptr) {
auto value = primitive->value_as_BinaryCrossEntropyGrad();
if (value == nullptr) {
MS_LOG(ERROR) << "param is nullptr"; MS_LOG(ERROR) << "param is nullptr";
return nullptr; return nullptr;
} }
bce_param->reduction = param->reduction();
return reinterpret_cast<OpParameter *>(bce_param);

auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));

param->op_parameter_.type_ = primitive->value_type();
param->reduction = value->reduction();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite


+ 11
- 10
mindspore/lite/src/ops/populate/binary_cross_entropy_populate.cc View File

@@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) { OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) {
BinaryCrossEntropyParameter *bce_param =
reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
if (bce_param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
return nullptr;
}
memset(bce_param, 0, sizeof(BinaryCrossEntropyParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_BinaryCrossEntropy(); auto value = primitive->value_as_BinaryCrossEntropy();
@@ -34,9 +27,17 @@ OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
bce_param->op_parameter_.type_ = primitive->value_type();
bce_param->reduction = value->reduction();
return reinterpret_cast<OpParameter *>(bce_param);

auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(BinaryCrossEntropyParameter));

param->op_parameter_.type_ = primitive->value_type();
param->reduction = value->reduction();
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR);


+ 14
- 11
mindspore/lite/src/ops/populate/broadcast_to_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_BroadcastTo;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateBroadcastToParameter(const void *prim) { OpParameter *PopulateBroadcastToParameter(const void *prim) {
auto *broadcast_param = reinterpret_cast<BroadcastToParameter *>(malloc(sizeof(BroadcastToParameter)));
if (broadcast_param == nullptr) {
MS_LOG(ERROR) << "malloc BroadcastToParameter failed.";
return nullptr;
}
memset(broadcast_param, 0, sizeof(BroadcastToParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_BroadcastTo(); auto value = primitive->value_as_BroadcastTo();
@@ -33,17 +27,26 @@ OpParameter *PopulateBroadcastToParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
broadcast_param->op_parameter_.type_ = primitive->value_type();

auto *param = reinterpret_cast<BroadcastToParameter *>(malloc(sizeof(BroadcastToParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc BroadcastToParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(BroadcastToParameter));

param->op_parameter_.type_ = primitive->value_type();
auto dst_shape = value->shape(); auto dst_shape = value->shape();
if (dst_shape == nullptr) { if (dst_shape == nullptr) {
MS_LOG(ERROR) << "dst_shape is nullptr"; MS_LOG(ERROR) << "dst_shape is nullptr";
free(param);
return nullptr; return nullptr;
} }
broadcast_param->shape_size_ = dst_shape->size();
for (size_t i = 0; i < broadcast_param->shape_size_; ++i) {
broadcast_param->shape_[i] = dst_shape->Get(i);
param->shape_size_ = dst_shape->size();
for (size_t i = 0; i < param->shape_size_; ++i) {
param->shape_[i] = dst_shape->Get(i);
} }
return reinterpret_cast<OpParameter *>(broadcast_param);
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_BroadcastTo, PopulateBroadcastToParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_BroadcastTo, PopulateBroadcastToParameter, SCHEMA_CUR)


+ 10
- 6
mindspore/lite/src/ops/populate/call_populate.cc View File

@@ -19,16 +19,20 @@ using mindspore::schema::PrimitiveType_Call;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateCallParameter(const void *prim) { OpParameter *PopulateCallParameter(const void *prim) {
OpParameter *call_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (call_parameter == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc CallParameter failed."; MS_LOG(ERROR) << "malloc CallParameter failed.";
return nullptr; return nullptr;
} }
memset(call_parameter, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
call_parameter->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(call_parameter);
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_Call, PopulateCallParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Call, PopulateCallParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 4
- 4
mindspore/lite/src/ops/populate/cast_populate.cc View File

@@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_Cast;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateCastParameter(const void *prim) { OpParameter *PopulateCastParameter(const void *prim) {
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *cast_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); auto *cast_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (cast_param == nullptr) { if (cast_param == nullptr) {
MS_LOG(ERROR) << "malloc CastParameter failed."; MS_LOG(ERROR) << "malloc CastParameter failed.";
return nullptr; return nullptr;
} }
memset(cast_param, 0, sizeof(OpParameter)); memset(cast_param, 0, sizeof(OpParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

cast_param->type_ = primitive->value_type(); cast_param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(cast_param); return reinterpret_cast<OpParameter *>(cast_param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Cast, PopulateCastParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Cast, PopulateCastParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 9
- 9
mindspore/lite/src/ops/populate/clip_populate.cc View File

@@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_Clip;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateClipParameter(const void *prim) { OpParameter *PopulateClipParameter(const void *prim) {
auto *act_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (act_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ClipParameter failed."; MS_LOG(ERROR) << "malloc ClipParameter failed.";
return nullptr; return nullptr;
} }
memset(act_param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
act_param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(act_param);
memset(param, 0, sizeof(OpParameter));

param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Clip, PopulateClipParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Clip, PopulateClipParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 9
- 9
mindspore/lite/src/ops/populate/common_populate.cc View File

@@ -19,20 +19,20 @@ using mindspore::schema::PrimitiveType_ZerosLike;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateCommonParameter(const void *prim) { OpParameter *PopulateCommonParameter(const void *prim) {
auto *common_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (common_parameter == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter failed."; MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr; return nullptr;
} }
memset(common_parameter, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
common_parameter->type_ = primitive->value_type();
return common_parameter;
memset(param, 0, sizeof(OpParameter));

param->type_ = primitive->value_type();
return param;
} }
} // namespace


REG_POPULATE(PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR)


+ 13
- 13
mindspore/lite/src/ops/populate/concat_populate.cc View File

@@ -19,26 +19,26 @@ using mindspore::schema::PrimitiveType_Concat;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateConcatParameter(const void *prim) { OpParameter *PopulateConcatParameter(const void *prim) {
auto *concat_param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter)));
if (concat_param == nullptr) {
MS_LOG(ERROR) << "malloc ConcatParameter failed.";
return nullptr;
}
memset(concat_param, 0, sizeof(ConcatParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
concat_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_Concat();
if (param == nullptr) {
auto value = primitive->value_as_Concat();
if (value == nullptr) {
MS_LOG(ERROR) << "param is nullptr"; MS_LOG(ERROR) << "param is nullptr";
return nullptr; return nullptr;
} }
concat_param->axis_ = static_cast<int>(param->axis());
return reinterpret_cast<OpParameter *>(concat_param);

auto *param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ConcatParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ConcatParameter));

param->op_parameter_.type_ = primitive->value_type();
param->axis_ = static_cast<int>(value->axis());
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Concat, PopulateConcatParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Concat, PopulateConcatParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 22
- 18
mindspore/lite/src/ops/populate/constant_of_shape_populate.cc View File

@@ -17,39 +17,42 @@
#include "nnacl/constant_of_shape_parameter.h" #include "nnacl/constant_of_shape_parameter.h"
using mindspore::schema::PrimitiveType_ConstantOfShape; using mindspore::schema::PrimitiveType_ConstantOfShape;


namespace mindspore::lite {
namespace {
namespace mindspore {
namespace lite {
OpParameter *PopulateConstantOfShapeParameter(const void *prim) { OpParameter *PopulateConstantOfShapeParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_ConstantOfShape();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<ConstantOfShapeParameter *>(malloc(sizeof(ConstantOfShapeParameter))); auto *param = reinterpret_cast<ConstantOfShapeParameter *>(malloc(sizeof(ConstantOfShapeParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "malloc ConstantOfShapeParameter failed."; MS_LOG(ERROR) << "malloc ConstantOfShapeParameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(ConstantOfShapeParameter)); memset(param, 0, sizeof(ConstantOfShapeParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->op_parameter_.type_ = primitive->value_type(); param->op_parameter_.type_ = primitive->value_type();
auto attr = primitive->value_as_ConstantOfShape();
if (attr == nullptr) {
MS_LOG(ERROR) << "attr is nullptr";
return nullptr;
}
auto val = attr->value();
if (val == nullptr) {
auto prim_val = value->value();
if (prim_val == nullptr) {
MS_LOG(ERROR) << "val is nullptr"; MS_LOG(ERROR) << "val is nullptr";
free(param);
return nullptr; return nullptr;
} }
auto value = std::vector<float>(val->begin(), val->end());
param->data_type_ = static_cast<int>(attr->data_type());
if (value.empty() || value.size() > 1) {
auto val = std::vector<float>(prim_val->begin(), prim_val->end());
param->data_type_ = static_cast<int>(value->data_type());
if (val.empty() || val.size() > 1) {
MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1."; MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1.";
} else { } else {
switch (param->data_type_) { switch (param->data_type_) {
case kNumberTypeFloat32: case kNumberTypeFloat32:
param->value_.f32_value_ = *(val->begin());
param->value_.f32_value_ = *(prim_val->begin());
break; break;
case kNumberTypeInt32: case kNumberTypeInt32:
param->value_.int32_value_ = *(val->begin());
param->value_.int32_value_ = *(prim_val->begin());
break; break;
default: default:
MS_LOG(ERROR) << "The value of constant of shape is invalid"; MS_LOG(ERROR) << "The value of constant of shape is invalid";
@@ -57,6 +60,7 @@ OpParameter *PopulateConstantOfShapeParameter(const void *prim) {
} }
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter, SCHEMA_CUR);
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

+ 44
- 42
mindspore/lite/src/ops/populate/conv2d_populate.cc View File

@@ -20,74 +20,76 @@ using mindspore::schema::PrimitiveType_Conv2DFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateConvParameter(const void *prim) { OpParameter *PopulateConvParameter(const void *prim) {
auto *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (conv_param == nullptr) {
MS_LOG(ERROR) << "malloc ConvParameter failed.";
return nullptr;
}
memset(conv_param, 0, sizeof(ConvParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
conv_param->op_parameter_.type_ = primitive->value_type();
auto conv_primitive = primitive->value_as_Conv2DFusion();
if (conv_primitive == nullptr) {
MS_LOG(ERROR) << "conv_primitive is nullptr";
auto value = primitive->value_as_Conv2DFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
auto kernel_size = conv_primitive->kernel_size();
auto stride = conv_primitive->stride();
auto pad_list = conv_primitive->pad_list();
auto dilation = conv_primitive->dilation();

auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ConvParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ConvParameter));

param->op_parameter_.type_ = primitive->value_type();
auto kernel_size = value->kernel_size();
auto stride = value->stride();
auto pad_list = value->pad_list();
auto dilation = value->dilation();
if (kernel_size == nullptr || stride == nullptr || dilation == nullptr) { if (kernel_size == nullptr || stride == nullptr || dilation == nullptr) {
MS_LOG(ERROR) << "nullptr"; MS_LOG(ERROR) << "nullptr";
free(param);
return nullptr; return nullptr;
} }
conv_param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
conv_param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
conv_param->group_ = static_cast<int>(conv_primitive->group());
conv_param->stride_h_ = static_cast<int>(*(stride->begin()));
conv_param->stride_w_ = static_cast<int>(*(stride->begin() + 1));
switch (conv_primitive->pad_mode()) {
param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
param->group_ = static_cast<int>(value->group());
param->stride_h_ = static_cast<int>(*(stride->begin()));
param->stride_w_ = static_cast<int>(*(stride->begin() + 1));
switch (value->pad_mode()) {
case schema::PadMode_SAME: case schema::PadMode_SAME:
conv_param->pad_mode_ = Pad_same;
param->pad_mode_ = Pad_same;
break; break;
case schema::PadMode_VALID: case schema::PadMode_VALID:
conv_param->pad_mode_ = Pad_valid;
param->pad_mode_ = Pad_valid;
break; break;
default: default:
conv_param->pad_mode_ = Pad_pad;
param->pad_mode_ = Pad_pad;
} }
if (pad_list == nullptr || pad_list->size() < 4) { if (pad_list == nullptr || pad_list->size() < 4) {
conv_param->pad_u_ = 0;
conv_param->pad_d_ = 0;
conv_param->pad_l_ = 0;
conv_param->pad_r_ = 0;
param->pad_u_ = 0;
param->pad_d_ = 0;
param->pad_l_ = 0;
param->pad_r_ = 0;
} else { } else {
conv_param->pad_u_ = static_cast<int>(*(pad_list->begin()));
conv_param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1));
conv_param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2));
conv_param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3));
param->pad_u_ = static_cast<int>(*(pad_list->begin()));
param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1));
param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2));
param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3));
} }
conv_param->dilation_h_ = static_cast<int>(*(dilation->begin()));
conv_param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1));
conv_param->input_channel_ = static_cast<int>(conv_primitive->in_channel());
conv_param->output_channel_ = static_cast<int>(conv_primitive->out_channel());
auto act_type = conv_primitive->activation_type();
param->dilation_h_ = static_cast<int>(*(dilation->begin()));
param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1));
param->input_channel_ = static_cast<int>(value->in_channel());
param->output_channel_ = static_cast<int>(value->out_channel());
auto act_type = value->activation_type();
switch (act_type) { switch (act_type) {
case schema::ActivationType_RELU: case schema::ActivationType_RELU:
conv_param->act_type_ = ActType_Relu;
param->act_type_ = ActType_Relu;
break; break;
case schema::ActivationType_RELU6: case schema::ActivationType_RELU6:
conv_param->act_type_ = ActType_Relu6;
param->act_type_ = ActType_Relu6;
break; break;
default: default:
conv_param->act_type_ = ActType_No;
param->act_type_ = ActType_No;
} }
return reinterpret_cast<OpParameter *>(conv_param);
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_Conv2DFusion, PopulateConvParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Conv2DFusion, PopulateConvParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 15
- 14
mindspore/lite/src/ops/populate/crop_and_resize_populate.cc View File

@@ -16,29 +16,30 @@
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "nnacl/resize_parameter.h" #include "nnacl/resize_parameter.h"
using mindspore::schema::PrimitiveType_CropAndResize; using mindspore::schema::PrimitiveType_CropAndResize;

namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateCropAndResizeParameter(const void *prim) { OpParameter *PopulateCropAndResizeParameter(const void *prim) {
auto *crop_resize_param = reinterpret_cast<CropAndResizeParameter *>(malloc(sizeof(CropAndResizeParameter)));
if (crop_resize_param == nullptr) {
MS_LOG(ERROR) << "malloc CropAndResizeParameter failed.";
return nullptr;
}
memset(crop_resize_param, 0, sizeof(CropAndResizeParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
crop_resize_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_CropAndResize();
auto value = primitive->value_as_CropAndResize();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<CropAndResizeParameter *>(malloc(sizeof(CropAndResizeParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc CropAndResizeParameter failed.";
return nullptr; return nullptr;
} }
crop_resize_param->method_ = static_cast<int>(param->method());
crop_resize_param->extrapolation_value_ = param->extrapolation_value();
return reinterpret_cast<OpParameter *>(crop_resize_param);
memset(param, 0, sizeof(CropAndResizeParameter));

param->op_parameter_.type_ = primitive->value_type();
param->method_ = static_cast<int>(value->method());
param->extrapolation_value_ = value->extrapolation_value();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_CropAndResize, PopulateCropAndResizeParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_CropAndResize, PopulateCropAndResizeParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite


+ 21
- 18
mindspore/lite/src/ops/populate/crop_populate.cc View File

@@ -19,39 +19,42 @@ using mindspore::schema::PrimitiveType_Crop;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateCropParameter(const void *prim) { OpParameter *PopulateCropParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto crop_prim = primitive->value_as_Crop();
if (crop_prim == nullptr) {
MS_LOG(ERROR) << "crop_prim is nullptr";
auto value = primitive->value_as_Crop();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
auto param_offset = crop_prim->offsets();

auto *param = reinterpret_cast<CropParameter *>(malloc(sizeof(CropParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc CropParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(CropParameter));

auto param_offset = value->offsets();
if (param_offset == nullptr) { if (param_offset == nullptr) {
MS_LOG(ERROR) << "param_offset is nullptr"; MS_LOG(ERROR) << "param_offset is nullptr";
free(param);
return nullptr; return nullptr;
} }
if (param_offset->size() > COMM_SHAPE_SIZE) { if (param_offset->size() > COMM_SHAPE_SIZE) {
MS_LOG(ERROR) << "crop_param offset size(" << param_offset->size() << ") should <= " << COMM_SHAPE_SIZE;
MS_LOG(ERROR) << "param offset size(" << param_offset->size() << ") should <= " << COMM_SHAPE_SIZE;
free(param);
return nullptr; return nullptr;
} }
auto *crop_param = reinterpret_cast<CropParameter *>(malloc(sizeof(CropParameter)));
if (crop_param == nullptr) {
MS_LOG(ERROR) << "malloc CropParameter failed.";
return nullptr;
}
memset(crop_param, 0, sizeof(CropParameter));
crop_param->op_parameter_.type_ = primitive->value_type();
crop_param->axis_ = crop_prim->axis();
crop_param->offset_size_ = param_offset->size();

param->op_parameter_.type_ = primitive->value_type();
param->axis_ = value->axis();
param->offset_size_ = param_offset->size();
for (size_t i = 0; i < param_offset->size(); ++i) { for (size_t i = 0; i < param_offset->size(); ++i) {
crop_param->offset_[i] = *(param_offset->begin() + i);
param->offset_[i] = *(param_offset->begin() + i);
} }
return reinterpret_cast<OpParameter *>(crop_param);
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Crop, PopulateCropParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Crop, PopulateCropParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 15
- 10
mindspore/lite/src/ops/populate/cumsum_populate.cc View File

@@ -19,22 +19,27 @@ using mindspore::schema::PrimitiveType_CumSum;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateCumSumParameter(const void *prim) { OpParameter *PopulateCumSumParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
auto cumsum_prim = primitive->value_as_CumSum();
CumSumParameter *cumsum_param = reinterpret_cast<CumSumParameter *>(malloc(sizeof(CumSumParameter)));
if (cumsum_param == nullptr) {
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_CumSum();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<CumSumParameter *>(malloc(sizeof(CumSumParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc CumsumParameter failed."; MS_LOG(ERROR) << "malloc CumsumParameter failed.";
return nullptr; return nullptr;
} }
memset(cumsum_param, 0, sizeof(CumSumParameter));
cumsum_param->op_parameter_.type_ = primitive->value_type();
cumsum_param->exclusive_ = cumsum_prim->exclusive();
cumsum_param->reverse_ = cumsum_prim->reverse();
return reinterpret_cast<OpParameter *>(cumsum_param);
memset(param, 0, sizeof(CumSumParameter));

param->op_parameter_.type_ = primitive->value_type();
param->exclusive_ = value->exclusive();
param->reverse_ = value->reverse();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_CumSum, PopulateCumSumParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_CumSum, PopulateCumSumParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 6
- 5
mindspore/lite/src/ops/populate/custom_extract_features_populate.cc View File

@@ -18,20 +18,21 @@ using mindspore::schema::PrimitiveType_CustomExtractFeatures;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateExtractFeaturesParameter(const void *prim) { OpParameter *PopulateExtractFeaturesParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "new OpParameter failed."; MS_LOG(ERROR) << "new OpParameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(OpParameter)); memset(param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->type_ = primitive->value_type(); param->type_ = primitive->value_type();
return param;
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_CustomExtractFeatures, PopulateExtractFeaturesParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_CustomExtractFeatures, PopulateExtractFeaturesParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 6
- 3
mindspore/lite/src/ops/populate/custom_normalize_populate.cc View File

@@ -19,17 +19,20 @@ using mindspore::schema::PrimitiveType_CustomNormalize;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateCustomNormalizeParameter(const void *prim) { OpParameter *PopulateCustomNormalizeParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "new OpParameter failed."; MS_LOG(ERROR) << "new OpParameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(OpParameter)); memset(param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->type_ = primitive->value_type(); param->type_ = primitive->value_type();
return param;
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_CustomNormalize, PopulateCustomNormalizeParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_CustomNormalize, PopulateCustomNormalizeParameter, SCHEMA_CUR);


} // namespace lite } // namespace lite


+ 9
- 7
mindspore/lite/src/ops/populate/custom_predict_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_CustomPredict;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateCustomPredictParameter(const void *prim) { OpParameter *PopulateCustomPredictParameter(const void *prim) {
PredictParameter *param = reinterpret_cast<PredictParameter *>(malloc(sizeof(PredictParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc param failed.";
return nullptr;
}
memset(param, 0, sizeof(PredictParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_CustomPredict(); auto value = primitive->value_as_CustomPredict();
@@ -33,12 +27,20 @@ OpParameter *PopulateCustomPredictParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }

auto *param = reinterpret_cast<PredictParameter *>(malloc(sizeof(PredictParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc param failed.";
return nullptr;
}
memset(param, 0, sizeof(PredictParameter));

param->op_parameter_.type_ = primitive->value_type(); param->op_parameter_.type_ = primitive->value_type();
param->output_num = value->output_num(); param->output_num = value->output_num();
param->weight_threshold = value->weight_threshold(); param->weight_threshold = value->weight_threshold();
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }
REG_POPULATE(PrimitiveType_CustomPredict, PopulateCustomPredictParameter, SCHEMA_CUR);


REG_POPULATE(PrimitiveType_CustomPredict, PopulateCustomPredictParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 46
- 43
mindspore/lite/src/ops/populate/deconv2d_populate.cc View File

@@ -21,74 +21,77 @@ using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateDeconvParameter(const void *prim) { OpParameter *PopulateDeconvParameter(const void *prim) {
auto *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (conv_param == nullptr) {
MS_LOG(ERROR) << "malloc ConvParameter failed.";
return nullptr;
}
memset(conv_param, 0, sizeof(ConvParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
conv_param->op_parameter_.type_ = primitive->value_type();
auto conv_primitive = primitive->value_as_Conv2dTransposeFusion();
if (conv_primitive == nullptr) {
MS_LOG(ERROR) << "conv_primitive is nullptr";
auto value = primitive->value_as_Conv2dTransposeFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ConvParameter failed.";
return nullptr; return nullptr;
} }
auto kernel_size = conv_primitive->kernel_size();
auto stride = conv_primitive->stride();
auto pad_list = conv_primitive->pad_list();
auto dilation = conv_primitive->dilation();
auto output_paddings = conv_primitive->output_paddings();
memset(param, 0, sizeof(ConvParameter));

param->op_parameter_.type_ = primitive->value_type();
auto kernel_size = value->kernel_size();
auto stride = value->stride();
auto pad_list = value->pad_list();
auto dilation = value->dilation();
auto output_paddings = value->output_paddings();
if (kernel_size == nullptr || stride == nullptr || dilation == nullptr || output_paddings == nullptr) { if (kernel_size == nullptr || stride == nullptr || dilation == nullptr || output_paddings == nullptr) {
MS_LOG(ERROR) << "nullptr"; MS_LOG(ERROR) << "nullptr";
free(param);
return nullptr; return nullptr;
} }
conv_param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
conv_param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
conv_param->group_ = static_cast<int>(conv_primitive->group());
conv_param->stride_h_ = static_cast<int>(*(stride->begin()));
conv_param->stride_w_ = static_cast<int>(*(stride->begin() + 1));
conv_param->output_padding_h_ = static_cast<int>(*(output_paddings->begin()));
conv_param->output_padding_w_ = static_cast<int>(*(output_paddings->begin() + 1));
switch (conv_primitive->pad_mode()) {
param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
param->group_ = static_cast<int>(value->group());
param->stride_h_ = static_cast<int>(*(stride->begin()));
param->stride_w_ = static_cast<int>(*(stride->begin() + 1));
param->output_padding_h_ = static_cast<int>(*(output_paddings->begin()));
param->output_padding_w_ = static_cast<int>(*(output_paddings->begin() + 1));
switch (value->pad_mode()) {
case schema::PadMode_SAME: case schema::PadMode_SAME:
conv_param->pad_mode_ = Pad_same;
param->pad_mode_ = Pad_same;
break; break;
case schema::PadMode_VALID: case schema::PadMode_VALID:
conv_param->pad_mode_ = Pad_valid;
param->pad_mode_ = Pad_valid;
break; break;
default: default:
conv_param->pad_mode_ = Pad_pad;
param->pad_mode_ = Pad_pad;
} }
if (pad_list == nullptr || pad_list->size() < 4) { if (pad_list == nullptr || pad_list->size() < 4) {
conv_param->pad_u_ = 0;
conv_param->pad_d_ = 0;
conv_param->pad_l_ = 0;
conv_param->pad_r_ = 0;
param->pad_u_ = 0;
param->pad_d_ = 0;
param->pad_l_ = 0;
param->pad_r_ = 0;
} else { } else {
conv_param->pad_u_ = static_cast<int>(*(pad_list->begin()));
conv_param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1));
conv_param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2));
conv_param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3));
param->pad_u_ = static_cast<int>(*(pad_list->begin()));
param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1));
param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2));
param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3));
} }
conv_param->dilation_h_ = static_cast<int>(*(dilation->begin()));
conv_param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1));
conv_param->input_channel_ = static_cast<int>(conv_primitive->in_channel());
conv_param->output_channel_ = static_cast<int>(conv_primitive->out_channel());
auto act_type = conv_primitive->activation_type();
param->dilation_h_ = static_cast<int>(*(dilation->begin()));
param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1));
param->input_channel_ = static_cast<int>(value->in_channel());
param->output_channel_ = static_cast<int>(value->out_channel());
auto act_type = value->activation_type();
switch (act_type) { switch (act_type) {
case schema::ActivationType_RELU: case schema::ActivationType_RELU:
conv_param->act_type_ = ActType_Relu;
param->act_type_ = ActType_Relu;
break; break;
case schema::ActivationType_RELU6: case schema::ActivationType_RELU6:
conv_param->act_type_ = ActType_Relu6;
param->act_type_ = ActType_Relu6;
break; break;
default: default:
conv_param->act_type_ = ActType_No;
param->act_type_ = ActType_No;
break; break;
} }
return reinterpret_cast<OpParameter *>(conv_param);
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_Conv2dTransposeFusion, PopulateDeconvParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Conv2dTransposeFusion, PopulateDeconvParameter, SCHEMA_CUR)


+ 4
- 2
mindspore/lite/src/ops/populate/default_populate.cc View File

@@ -22,14 +22,16 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *DefaultPopulateParameter(const void *prim) { OpParameter *DefaultPopulateParameter(const void *prim) {
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = static_cast<OpParameter *>(malloc(sizeof(OpParameter))); auto *param = static_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "Malloc OpParameter failed."; MS_LOG(ERROR) << "Malloc OpParameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(OpParameter)); memset(param, 0, sizeof(OpParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->type_ = primitive->value_type(); param->type_ = primitive->value_type();
return param; return param;
} }


+ 13
- 11
mindspore/lite/src/ops/populate/depth_to_space_populate.cc View File

@@ -21,22 +21,24 @@ namespace mindspore {
namespace lite { namespace lite {
namespace { namespace {
OpParameter *PopulateDepthToSpaceParameter(const void *prim) { OpParameter *PopulateDepthToSpaceParameter(const void *prim) {
auto *depth_space_param = reinterpret_cast<DepthToSpaceParameter *>(malloc(sizeof(DepthToSpaceParameter)));
if (depth_space_param == nullptr) {
MS_LOG(ERROR) << "malloc DepthToSpaceParameter failed.";
return nullptr;
}
memset(depth_space_param, 0, sizeof(DepthToSpaceParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto param = primitive->value_as_DepthToSpace();
auto value = primitive->value_as_DepthToSpace();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<DepthToSpaceParameter *>(malloc(sizeof(DepthToSpaceParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc DepthToSpaceParameter failed.";
return nullptr; return nullptr;
} }
depth_space_param->op_parameter_.type_ = primitive->value_type();
depth_space_param->block_size_ = param->block_size();
return reinterpret_cast<OpParameter *>(depth_space_param);
memset(param, 0, sizeof(DepthToSpaceParameter));

param->op_parameter_.type_ = primitive->value_type();
param->block_size_ = value->block_size();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace } // namespace




+ 5
- 4
mindspore/lite/src/ops/populate/depthwise_conv2d_populate.cc View File

@@ -19,13 +19,14 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateConvDwParameter(const void *primitive) { OpParameter *PopulateConvDwParameter(const void *primitive) {
auto *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (conv_param == nullptr) {
auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ConvParameter failed."; MS_LOG(ERROR) << "malloc ConvParameter failed.";
return nullptr; return nullptr;
} }
memset(conv_param, 0, sizeof(ConvParameter));
return reinterpret_cast<OpParameter *>(conv_param);
memset(param, 0, sizeof(ConvParameter));

return reinterpret_cast<OpParameter *>(param);
} }
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 26
- 26
mindspore/lite/src/ops/populate/detection_post_process_populate.cc View File

@@ -19,43 +19,43 @@ using mindspore::schema::PrimitiveType_DetectionPostProcess;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateDetectionPostProcessParameter(const void *prim) { OpParameter *PopulateDetectionPostProcessParameter(const void *prim) {
auto *detection_post_process_parameter =
reinterpret_cast<DetectionPostProcessParameter *>(malloc(sizeof(DetectionPostProcessParameter)));
if (detection_post_process_parameter == nullptr) {
MS_LOG(ERROR) << "malloc EluParameter failed.";
return nullptr;
}
memset(detection_post_process_parameter, 0, sizeof(DetectionPostProcessParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
detection_post_process_parameter->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_DetectionPostProcess();
auto value = primitive->value_as_DetectionPostProcess();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<DetectionPostProcessParameter *>(malloc(sizeof(DetectionPostProcessParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc DetectionPostProcessParameter failed.";
return nullptr; return nullptr;
} }
auto scale = param->scale();
memset(param, 0, sizeof(DetectionPostProcessParameter));

param->op_parameter_.type_ = primitive->value_type();
auto scale = value->scale();
if (scale == nullptr) { if (scale == nullptr) {
MS_LOG(ERROR) << "scale is nullptr"; MS_LOG(ERROR) << "scale is nullptr";
free(param);
return nullptr; return nullptr;
} }
detection_post_process_parameter->h_scale_ = *(scale->begin());
detection_post_process_parameter->w_scale_ = *(scale->begin() + 1);
detection_post_process_parameter->x_scale_ = *(scale->begin() + 2);
detection_post_process_parameter->y_scale_ = *(scale->begin() + 3);
detection_post_process_parameter->nms_iou_threshold_ = param->nms_iou_threshold();
detection_post_process_parameter->nms_score_threshold_ = param->nms_score_threshold();
detection_post_process_parameter->max_detections_ = param->max_detections();
detection_post_process_parameter->detections_per_class_ = param->detections_per_class();
detection_post_process_parameter->max_classes_per_detection_ = param->max_classes_per_detection();
detection_post_process_parameter->num_classes_ = param->num_classes();
detection_post_process_parameter->use_regular_nms_ = param->use_regular_nms();
return reinterpret_cast<OpParameter *>(detection_post_process_parameter);
param->h_scale_ = *(scale->begin());
param->w_scale_ = *(scale->begin() + 1);
param->x_scale_ = *(scale->begin() + 2);
param->y_scale_ = *(scale->begin() + 3);
param->nms_iou_threshold_ = value->nms_iou_threshold();
param->nms_score_threshold_ = value->nms_score_threshold();
param->max_detections_ = value->max_detections();
param->detections_per_class_ = value->detections_per_class();
param->max_classes_per_detection_ = value->max_classes_per_detection();
param->num_classes_ = value->num_classes();
param->use_regular_nms_ = value->use_regular_nms();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_DetectionPostProcess, PopulateDetectionPostProcessParameter, SCHEMA_CUR);


REG_POPULATE(PrimitiveType_DetectionPostProcess, PopulateDetectionPostProcessParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 2
- 1
mindspore/lite/src/ops/populate/div_populate.cc View File

@@ -21,9 +21,10 @@ namespace lite {
OpParameter *PopulateDivParameter(const void *prim) { OpParameter *PopulateDivParameter(const void *prim) {
auto *param = PopulateArithmeticCommonPara(prim); auto *param = PopulateArithmeticCommonPara(prim);
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
MS_LOG(ERROR) << "get PopulateArithmeticCommonPara failed.";
return nullptr; return nullptr;
} }

return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }




+ 10
- 10
mindspore/lite/src/ops/populate/eltwise_populate.cc View File

@@ -19,24 +19,24 @@ using mindspore::schema::PrimitiveType_Eltwise;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateEltwiseParameter(const void *prim) { OpParameter *PopulateEltwiseParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_Eltwise();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

ArithmeticParameter *param = PopulateArithmeticCommonPara(prim); ArithmeticParameter *param = PopulateArithmeticCommonPara(prim);
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr; return nullptr;
} }
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto eltwise_param = primitive->value_as_Eltwise();
if (eltwise_param == nullptr) {
MS_LOG(ERROR) << "eltwise_param is nullptr";
return nullptr;
}
param->eltwise_mode_ = eltwise_param->mode();

param->eltwise_mode_ = value->mode();
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Eltwise, PopulateEltwiseParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Eltwise, PopulateEltwiseParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 14
- 13
mindspore/lite/src/ops/populate/elu_populate.cc View File

@@ -19,26 +19,27 @@ using mindspore::schema::PrimitiveType_Elu;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateEluParameter(const void *prim) { OpParameter *PopulateEluParameter(const void *prim) {
auto *elu_parameter = reinterpret_cast<EluParameter *>(malloc(sizeof(EluParameter)));
if (elu_parameter == nullptr) {
MS_LOG(ERROR) << "malloc EluParameter failed.";
return nullptr;
}
memset(elu_parameter, 0, sizeof(EluParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
elu_parameter->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_Elu();
auto value = primitive->value_as_Elu();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<EluParameter *>(malloc(sizeof(EluParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc EluParameter failed.";
return nullptr; return nullptr;
} }
elu_parameter->alpha_ = param->alpha();
return reinterpret_cast<OpParameter *>(elu_parameter);
memset(param, 0, sizeof(EluParameter));

param->op_parameter_.type_ = primitive->value_type();
param->alpha_ = value->alpha();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_Elu, PopulateEluParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Elu, PopulateEluParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 8
- 8
mindspore/lite/src/ops/populate/embedding_lookup_populate.cc View File

@@ -19,15 +19,7 @@ using mindspore::schema::PrimitiveType_EmbeddingLookupFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

OpParameter *PopulateEmbeddingLookupParameter(const void *prim) { OpParameter *PopulateEmbeddingLookupParameter(const void *prim) {
auto *param = reinterpret_cast<EmbeddingLookupParameter *>(malloc(sizeof(EmbeddingLookupParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc EmbeddingLookupParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(EmbeddingLookupParameter));

auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_EmbeddingLookupFusion(); auto value = primitive->value_as_EmbeddingLookupFusion();
@@ -35,6 +27,14 @@ OpParameter *PopulateEmbeddingLookupParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }

auto *param = reinterpret_cast<EmbeddingLookupParameter *>(malloc(sizeof(EmbeddingLookupParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc EmbeddingLookupParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(EmbeddingLookupParameter));

param->op_parameter_.type_ = primitive->value_type(); param->op_parameter_.type_ = primitive->value_type();
param->max_norm_ = value->max_norm(); param->max_norm_ = value->max_norm();
if (param->max_norm_ < 0) { if (param->max_norm_ < 0) {


+ 16
- 15
mindspore/lite/src/ops/populate/exp_populate.cc View File

@@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_ExpFusion;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateExpParameter(const void *prim) { OpParameter *PopulateExpParameter(const void *prim) {
auto *exp_parameter = reinterpret_cast<ExpParameter *>(malloc(sizeof(ExpParameter)));
if (exp_parameter == nullptr) {
MS_LOG(ERROR) << "malloc ExpParameter failed.";
return nullptr;
}
memset(exp_parameter, 0, sizeof(ExpParameter));

auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_ExpFusion(); auto value = primitive->value_as_ExpFusion();
@@ -34,16 +27,24 @@ OpParameter *PopulateExpParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
exp_parameter->op_parameter_.type_ = primitive->value_type();
exp_parameter->base_ = value->base();
exp_parameter->scale_ = value->scale();
exp_parameter->shift_ = value->shift();
if (exp_parameter->base_ != -1 && exp_parameter->base_ <= 0) {
MS_LOG(ERROR) << "Exp base must be strictly positive, got " << exp_parameter->base_;
free(exp_parameter);

auto *param = reinterpret_cast<ExpParameter *>(malloc(sizeof(ExpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ExpParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ExpParameter));

param->op_parameter_.type_ = primitive->value_type();
param->base_ = value->base();
param->scale_ = value->scale();
param->shift_ = value->shift();
if (param->base_ != -1 && param->base_ <= 0) {
MS_LOG(ERROR) << "Exp base must be strictly positive, got " << param->base_;
free(param);
return nullptr; return nullptr;
} }
return reinterpret_cast<OpParameter *>(exp_parameter);
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_ExpFusion, PopulateExpParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ExpFusion, PopulateExpParameter, SCHEMA_CUR)


+ 9
- 9
mindspore/lite/src/ops/populate/expand_dims_populate.cc View File

@@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_ExpandDims;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateExpandDimsParameter(const void *prim) { OpParameter *PopulateExpandDimsParameter(const void *prim) {
auto *expand_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (expand_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ExpandDimsParameter failed."; MS_LOG(ERROR) << "malloc ExpandDimsParameter failed.";
return nullptr; return nullptr;
} }
memset(expand_param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
expand_param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(expand_param);
memset(param, 0, sizeof(OpParameter));

param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_ExpandDims, PopulateExpandDimsParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ExpandDims, PopulateExpandDimsParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 9
- 9
mindspore/lite/src/ops/populate/fill_populate.cc View File

@@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_Fill;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateFillParameter(const void *prim) { OpParameter *PopulateFillParameter(const void *prim) {
auto *fill_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (fill_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc FillParameter failed."; MS_LOG(ERROR) << "malloc FillParameter failed.";
return nullptr; return nullptr;
} }
memset(fill_param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
fill_param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(fill_param);
memset(param, 0, sizeof(OpParameter));

param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Fill, PopulateFillParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Fill, PopulateFillParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 8
- 7
mindspore/lite/src/ops/populate/flatten_populate.cc View File

@@ -19,17 +19,18 @@ using mindspore::schema::PrimitiveType_Flatten;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateFlattenParameter(const void *prim) { OpParameter *PopulateFlattenParameter(const void *prim) {
auto *flatten_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (flatten_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc FlattenParameter failed."; MS_LOG(ERROR) << "malloc FlattenParameter failed.";
return nullptr; return nullptr;
} }
memset(flatten_param, 0, sizeof(OpParameter));
memset(param, 0, sizeof(OpParameter));


auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
flatten_param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(flatten_param);
param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_Flatten, PopulateFlattenParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Flatten, PopulateFlattenParameter, SCHEMA_CUR)


+ 23
- 23
mindspore/lite/src/ops/populate/full_connection_populate.cc View File

@@ -19,37 +19,37 @@ using mindspore::schema::PrimitiveType_FullConnection;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateFullconnectionParameter(const void *prim) { OpParameter *PopulateFullconnectionParameter(const void *prim) {
auto *matmul_param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter)));
if (matmul_param == nullptr) {
MS_LOG(ERROR) << "malloc MatMulParameter failed.";
return nullptr;
}
memset(matmul_param, 0, sizeof(MatMulParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
matmul_param->op_parameter_.type_ = primitive->value_type();
auto full_conn_prim = primitive->value_as_FullConnection();
if (full_conn_prim == nullptr) {
MS_LOG(ERROR) << "full_conn_prim is nullptr";
auto value = primitive->value_as_FullConnection();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc MatMulParameter failed.";
return nullptr; return nullptr;
} }
matmul_param->b_transpose_ = true;
matmul_param->a_transpose_ = false;
matmul_param->has_bias_ = full_conn_prim->has_bias();
if (full_conn_prim->activation_type() == schema::ActivationType_RELU) {
matmul_param->act_type_ = ActType_Relu;
} else if (full_conn_prim->activation_type() == schema::ActivationType_RELU6) {
matmul_param->act_type_ = ActType_Relu6;
memset(param, 0, sizeof(MatMulParameter));

param->op_parameter_.type_ = primitive->value_type();
param->b_transpose_ = true;
param->a_transpose_ = false;
param->has_bias_ = value->has_bias();
if (value->activation_type() == schema::ActivationType_RELU) {
param->act_type_ = ActType_Relu;
} else if (value->activation_type() == schema::ActivationType_RELU6) {
param->act_type_ = ActType_Relu6;
} else { } else {
matmul_param->act_type_ = ActType_No;
param->act_type_ = ActType_No;
} }
matmul_param->axis_ = full_conn_prim->axis();
matmul_param->use_axis_ = full_conn_prim->use_axis();
return reinterpret_cast<OpParameter *>(matmul_param);
param->axis_ = value->axis();
param->use_axis_ = value->use_axis();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_FullConnection, PopulateFullconnectionParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_FullConnection, PopulateFullconnectionParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 13
- 11
mindspore/lite/src/ops/populate/fused_batchnorm_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_FusedBatchNorm;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateFusedBatchNorm(const void *prim) { OpParameter *PopulateFusedBatchNorm(const void *prim) {
auto *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
if (batch_norm_param == nullptr) {
MS_LOG(ERROR) << "malloc BatchNormParameter failed.";
return nullptr;
}
memset(batch_norm_param, 0, sizeof(BatchNormParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_FusedBatchNorm(); auto value = primitive->value_as_FusedBatchNorm();
@@ -33,11 +27,19 @@ OpParameter *PopulateFusedBatchNorm(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
batch_norm_param->op_parameter_.type_ = primitive->value_type();
batch_norm_param->epsilon_ = value->epsilon();
batch_norm_param->momentum_ = value->momentum();
batch_norm_param->fused_ = true;
return reinterpret_cast<OpParameter *>(batch_norm_param);

auto *param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc BatchNormParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(BatchNormParameter));

param->op_parameter_.type_ = primitive->value_type();
param->epsilon_ = value->epsilon();
param->momentum_ = value->momentum();
param->fused_ = true;
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_FusedBatchNorm, PopulateFusedBatchNorm, SCHEMA_CUR) REG_POPULATE(PrimitiveType_FusedBatchNorm, PopulateFusedBatchNorm, SCHEMA_CUR)


+ 9
- 9
mindspore/lite/src/ops/populate/gather_nd_populate.cc View File

@@ -19,20 +19,20 @@ using mindspore::schema::PrimitiveType_GatherNd;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateGatherNdParameter(const void *prim) { OpParameter *PopulateGatherNdParameter(const void *prim) {
auto *gather_nd_param = reinterpret_cast<GatherNdParameter *>(malloc(sizeof(GatherNdParameter)));
if (gather_nd_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<GatherNdParameter *>(malloc(sizeof(GatherNdParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc GatherNdParameter failed."; MS_LOG(ERROR) << "malloc GatherNdParameter failed.";
return nullptr; return nullptr;
} }
memset(gather_nd_param, 0, sizeof(GatherNdParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
gather_nd_param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(gather_nd_param);
memset(param, 0, sizeof(GatherNdParameter));

param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_GatherNd, PopulateGatherNdParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_GatherNd, PopulateGatherNdParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 8
- 9
mindspore/lite/src/ops/populate/gather_populate.cc View File

@@ -19,21 +19,20 @@ using mindspore::schema::PrimitiveType_Gather;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateGatherParameter(const void *prim) { OpParameter *PopulateGatherParameter(const void *prim) {
auto *gather_param = reinterpret_cast<GatherParameter *>(malloc(sizeof(GatherParameter)));
if (gather_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<GatherParameter *>(malloc(sizeof(GatherParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc GatherParameter failed."; MS_LOG(ERROR) << "malloc GatherParameter failed.";
return nullptr; return nullptr;
} }
memset(gather_param, 0, sizeof(GatherParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
gather_param->op_parameter_.type_ = primitive->value_type();
memset(param, 0, sizeof(GatherParameter));


return reinterpret_cast<OpParameter *>(gather_param);
param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Gather, PopulateGatherParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Gather, PopulateGatherParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 13
- 14
mindspore/lite/src/ops/populate/gru_populate.cc View File

@@ -19,27 +19,26 @@ using mindspore::schema::PrimitiveType_GRU;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateGruParameter(const void *prim) { OpParameter *PopulateGruParameter(const void *prim) {
auto *gru_param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter)));
if (gru_param == nullptr) {
MS_LOG(ERROR) << "malloc GruParameter failed.";
return nullptr;
}
memset(gru_param, 0, sizeof(GruParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
gru_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_GRU();
if (param == nullptr) {
free(gru_param);
auto value = primitive->value_as_GRU();
if (value == nullptr) {
MS_LOG(ERROR) << "param is nullptr."; MS_LOG(ERROR) << "param is nullptr.";
return nullptr; return nullptr;
} }
gru_param->bidirectional_ = param->bidirectional();
return reinterpret_cast<OpParameter *>(gru_param);

auto *param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc GruParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(GruParameter));

param->op_parameter_.type_ = primitive->value_type();
param->bidirectional_ = value->bidirectional();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_GRU, PopulateGruParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_GRU, PopulateGruParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 5
- 2
mindspore/lite/src/ops/populate/hashtable_lookup_populate.cc View File

@@ -19,17 +19,20 @@ using mindspore::schema::PrimitiveType_HashtableLookup;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateHashtableLookupParameter(const void *prim) { OpParameter *PopulateHashtableLookupParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "new OpParameter failed."; MS_LOG(ERROR) << "new OpParameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(OpParameter)); memset(param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->type_ = primitive->value_type(); param->type_ = primitive->value_type();
return param; return param;
} }

REG_POPULATE(PrimitiveType_HashtableLookup, PopulateHashtableLookupParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_HashtableLookup, PopulateHashtableLookupParameter, SCHEMA_CUR);


} // namespace lite } // namespace lite


+ 11
- 10
mindspore/lite/src/ops/populate/instance_norm_populate.cc View File

@@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_InstanceNorm;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateInstanceNormParameter(const void *prim) { OpParameter *PopulateInstanceNormParameter(const void *prim) {
auto *instance_norm_param = reinterpret_cast<InstanceNormParameter *>(malloc(sizeof(InstanceNormParameter)));
if (instance_norm_param == nullptr) {
MS_LOG(ERROR) << "malloc InstanceNormParameter failed.";
return nullptr;
}
memset(instance_norm_param, 0, sizeof(InstanceNormParameter));

auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_InstanceNorm(); auto value = primitive->value_as_InstanceNorm();
@@ -34,9 +27,17 @@ OpParameter *PopulateInstanceNormParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
instance_norm_param->op_parameter_.type_ = primitive->value_type();
instance_norm_param->epsilon_ = value->epsilon();
return reinterpret_cast<OpParameter *>(instance_norm_param);

auto *param = reinterpret_cast<InstanceNormParameter *>(malloc(sizeof(InstanceNormParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc InstanceNormParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(InstanceNormParameter));

param->op_parameter_.type_ = primitive->value_type();
param->epsilon_ = value->epsilon();
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_InstanceNorm, PopulateInstanceNormParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_InstanceNorm, PopulateInstanceNormParameter, SCHEMA_CUR)


+ 17
- 17
mindspore/lite/src/ops/populate/l2_norm_populate.cc View File

@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <cstdint>
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "nnacl/l2_norm_parameter.h" #include "nnacl/l2_norm_parameter.h"
using mindspore::schema::PrimitiveType_L2NormalizeFusion; using mindspore::schema::PrimitiveType_L2NormalizeFusion;
@@ -21,13 +20,6 @@ using mindspore::schema::PrimitiveType_L2NormalizeFusion;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateL2NormParameter(const void *prim) { OpParameter *PopulateL2NormParameter(const void *prim) {
auto *l2_norm_parameter = reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter)));
if (l2_norm_parameter == nullptr) {
MS_LOG(ERROR) << "malloc L2NormParameter failed.";
return nullptr;
}
memset(l2_norm_parameter, 0, sizeof(L2NormParameter));

auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_L2NormalizeFusion(); auto value = primitive->value_as_L2NormalizeFusion();
@@ -35,32 +27,40 @@ OpParameter *PopulateL2NormParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
l2_norm_parameter->op_parameter_.type_ = primitive->value_type();


auto *param = reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc L2NormParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(L2NormParameter));

param->op_parameter_.type_ = primitive->value_type();
auto axis_vec = value->axis(); auto axis_vec = value->axis();
if (axis_vec == nullptr) { if (axis_vec == nullptr) {
MS_LOG(ERROR) << "axis_vec is nullptr"; MS_LOG(ERROR) << "axis_vec is nullptr";
free(param);
return nullptr; return nullptr;
} }
l2_norm_parameter->axis_num_ = axis_vec->size();
param->axis_num_ = axis_vec->size();


MS_ASSERT(axis_vec->size() < 8); MS_ASSERT(axis_vec->size() < 8);
for (size_t i = 0; i < axis_vec->size(); i++) { for (size_t i = 0; i < axis_vec->size(); i++) {
l2_norm_parameter->axis_[i] = static_cast<int>(axis_vec->Get(i));
param->axis_[i] = static_cast<int>(axis_vec->Get(i));
} }
if (value->epsilon() < 1e-6) { if (value->epsilon() < 1e-6) {
l2_norm_parameter->epsilon_ = 1e-6;
param->epsilon_ = 1e-6;
} else { } else {
l2_norm_parameter->epsilon_ = value->epsilon();
param->epsilon_ = value->epsilon();
} }
if (value->activation_type() == static_cast<int>(schema::ActivationType_RELU)) { if (value->activation_type() == static_cast<int>(schema::ActivationType_RELU)) {
l2_norm_parameter->act_type_ = ActType_Relu;
param->act_type_ = ActType_Relu;
} else if (value->activation_type() == static_cast<int>(schema::ActivationType_RELU6)) { } else if (value->activation_type() == static_cast<int>(schema::ActivationType_RELU6)) {
l2_norm_parameter->act_type_ = ActType_Relu6;
param->act_type_ = ActType_Relu6;
} else { } else {
l2_norm_parameter->act_type_ = ActType_No;
param->act_type_ = ActType_No;
} }
return reinterpret_cast<OpParameter *>(l2_norm_parameter);
return reinterpret_cast<OpParameter *>(param);
} }
REG_POPULATE(PrimitiveType_L2NormalizeFusion, PopulateL2NormParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_L2NormalizeFusion, PopulateL2NormParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 14
- 12
mindspore/lite/src/ops/populate/layer_norm_grad_populate.cc View File

@@ -21,23 +21,25 @@ using mindspore::schema::PrimitiveType_LayerNormGrad;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateLayerNormGradParameter(const void *prim) { OpParameter *PopulateLayerNormGradParameter(const void *prim) {
auto layer_norm_grad_parameter = reinterpret_cast<LayerNormGradParameter *>(malloc(sizeof(LayerNormGradParameter)));
if (layer_norm_grad_parameter == nullptr) {
MS_LOG(ERROR) << "malloc LayerNormParameter failed.";
return nullptr;
}
memset(layer_norm_grad_parameter, 0, sizeof(LayerNormGradParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
layer_norm_grad_parameter->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_LayerNormGrad();
auto value = primitive->value_as_LayerNormGrad();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto param = reinterpret_cast<LayerNormGradParameter *>(malloc(sizeof(LayerNormGradParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc LayerNormParameter failed.";
return nullptr; return nullptr;
} }
layer_norm_grad_parameter->begin_norm_axis_ = param->begin_norm_axis();
layer_norm_grad_parameter->begin_params_axis_ = param->begin_params_axis();
return reinterpret_cast<OpParameter *>(layer_norm_grad_parameter);
memset(param, 0, sizeof(LayerNormGradParameter));

param->op_parameter_.type_ = primitive->value_type();
param->begin_norm_axis_ = value->begin_norm_axis();
param->begin_params_axis_ = value->begin_params_axis();
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_LayerNormGrad, PopulateLayerNormGradParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_LayerNormGrad, PopulateLayerNormGradParameter, SCHEMA_CUR);


+ 17
- 14
mindspore/lite/src/ops/populate/layer_norm_populate.cc View File

@@ -17,28 +17,31 @@
#include <cstdint> #include <cstdint>
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
using mindspore::schema::PrimitiveType_LayerNormFusion; using mindspore::schema::PrimitiveType_LayerNormFusion;

namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateLayerNormParameter(const void *prim) { OpParameter *PopulateLayerNormParameter(const void *prim) {
auto layer_norm_parameter = reinterpret_cast<LayerNormParameter *>(malloc(sizeof(LayerNormParameter)));
if (layer_norm_parameter == nullptr) {
MS_LOG(ERROR) << "malloc LayerNormParameter failed.";
return nullptr;
}
memset(layer_norm_parameter, 0, sizeof(LayerNormParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
layer_norm_parameter->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_LayerNormFusion();
auto value = primitive->value_as_LayerNormFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto param = reinterpret_cast<LayerNormParameter *>(malloc(sizeof(LayerNormParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc LayerNormParameter failed.";
return nullptr; return nullptr;
} }
layer_norm_parameter->epsilon_ = param->epsilon();
layer_norm_parameter->elementwise_affine_ = param->elementwise_affine();
layer_norm_parameter->begin_norm_axis_ = static_cast<int>(param->begin_norm_axis());
layer_norm_parameter->begin_params_axis_ = static_cast<int>(param->begin_params_axis());
return reinterpret_cast<OpParameter *>(layer_norm_parameter);
memset(param, 0, sizeof(LayerNormParameter));

param->op_parameter_.type_ = primitive->value_type();
param->epsilon_ = value->epsilon();
param->elementwise_affine_ = value->elementwise_affine();
param->begin_norm_axis_ = static_cast<int>(value->begin_norm_axis());
param->begin_params_axis_ = static_cast<int>(value->begin_params_axis());
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_LayerNormFusion, PopulateLayerNormParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_LayerNormFusion, PopulateLayerNormParameter, SCHEMA_CUR)


+ 14
- 12
mindspore/lite/src/ops/populate/local_response_normalization_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_LRN;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateLocalResponseNormParameter(const void *prim) { OpParameter *PopulateLocalResponseNormParameter(const void *prim) {
auto *lrn_param = reinterpret_cast<LocalResponseNormParameter *>(malloc(sizeof(LocalResponseNormParameter)));
if (lrn_param == nullptr) {
MS_LOG(ERROR) << "malloc LocalResponseNormParameter failed.";
return nullptr;
}
memset(lrn_param, 0, sizeof(LocalResponseNormParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_LRN(); auto value = primitive->value_as_LRN();
@@ -33,12 +27,20 @@ OpParameter *PopulateLocalResponseNormParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
lrn_param->op_parameter_.type_ = primitive->value_type();
lrn_param->depth_radius_ = value->depth_radius();
lrn_param->bias_ = value->bias();
lrn_param->alpha_ = value->alpha();
lrn_param->beta_ = value->beta();
return reinterpret_cast<OpParameter *>(lrn_param);

auto *param = reinterpret_cast<LocalResponseNormParameter *>(malloc(sizeof(LocalResponseNormParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc LocalResponseNormParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(LocalResponseNormParameter));

param->op_parameter_.type_ = primitive->value_type();
param->depth_radius_ = value->depth_radius();
param->bias_ = value->bias();
param->alpha_ = value->alpha();
param->beta_ = value->beta();
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_LRN, PopulateLocalResponseNormParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_LRN, PopulateLocalResponseNormParameter, SCHEMA_CUR);


+ 14
- 14
mindspore/lite/src/ops/populate/log_softmax_populate.cc View File

@@ -19,26 +19,26 @@ using mindspore::schema::PrimitiveType_LogSoftmax;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateLogSoftmaxParameter(const void *prim) { OpParameter *PopulateLogSoftmaxParameter(const void *prim) {
auto *log_softmax_param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter)));
if (log_softmax_param == nullptr) {
MS_LOG(ERROR) << "malloc LogSoftmaxParameter failed.";
return nullptr;
}
memset(log_softmax_param, 0, sizeof(SoftmaxParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
log_softmax_param->op_parameter_.type_ = primitive->value_type();
auto prim_log_softmax = primitive->value_as_LogSoftmax();
if (prim_log_softmax == nullptr) {
MS_LOG(ERROR) << "prim_log_softmax is nullptr";
auto value = primitive->value_as_LogSoftmax();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SoftmaxParameter failed.";
return nullptr; return nullptr;
} }
log_softmax_param->axis_ = prim_log_softmax->axis();
return reinterpret_cast<OpParameter *>(log_softmax_param);
memset(param, 0, sizeof(SoftmaxParameter));

param->op_parameter_.type_ = primitive->value_type();
param->axis_ = value->axis();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_LogSoftmax, PopulateLogSoftmaxParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_LogSoftmax, PopulateLogSoftmaxParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite


+ 12
- 11
mindspore/lite/src/ops/populate/lsh_projection_populate.cc View File

@@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_LshProjection;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateLshProjectionParameter(const void *prim) { OpParameter *PopulateLshProjectionParameter(const void *prim) {
auto *lsh_project_param = reinterpret_cast<LshProjectionParameter *>(malloc(sizeof(LshProjectionParameter)));
if (lsh_project_param == nullptr) {
MS_LOG(ERROR) << "malloc LshProjectionParameter failed.";
return nullptr;
}
memset(lsh_project_param, 0, sizeof(LshProjectionParameter));

auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_LshProjection(); auto value = primitive->value_as_LshProjection();
@@ -34,11 +27,19 @@ OpParameter *PopulateLshProjectionParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
lsh_project_param->op_parameter_.type_ = primitive->value_type();
lsh_project_param->lsh_type_ = value->type();
return reinterpret_cast<OpParameter *>(lsh_project_param);

auto *param = reinterpret_cast<LshProjectionParameter *>(malloc(sizeof(LshProjectionParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc LshProjectionParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(LshProjectionParameter));

param->op_parameter_.type_ = primitive->value_type();
param->lsh_type_ = value->type();
return reinterpret_cast<OpParameter *>(param);
} }
REG_POPULATE(PrimitiveType_LshProjection, PopulateLshProjectionParameter, SCHEMA_CUR);


REG_POPULATE(PrimitiveType_LshProjection, PopulateLshProjectionParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 15
- 16
mindspore/lite/src/ops/populate/lstm_populate.cc View File

@@ -19,30 +19,29 @@ using mindspore::schema::PrimitiveType_LSTM;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateLstmParameter(const void *prim) { OpParameter *PopulateLstmParameter(const void *prim) {
auto *lstm_param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
if (lstm_param == nullptr) {
MS_LOG(ERROR) << "malloc LstmParameter failed.";
return nullptr;
}
memset(lstm_param, 0, sizeof(LstmParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
lstm_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_LSTM();
auto value = primitive->value_as_LSTM();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr.";
return nullptr;
}

auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
if (param == nullptr) { if (param == nullptr) {
free(lstm_param);
MS_LOG(ERROR) << "get Lstm param nullptr.";
MS_LOG(ERROR) << "malloc LstmParameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(LstmParameter));


lstm_param->bidirectional_ = param->bidirectional();
lstm_param->zoneout_cell_ = param->zoneout_cell();
lstm_param->zoneout_hidden_ = param->zoneout_hidden();
return reinterpret_cast<OpParameter *>(lstm_param);
param->op_parameter_.type_ = primitive->value_type();
param->bidirectional_ = value->bidirectional();
param->zoneout_cell_ = value->zoneout_cell();
param->zoneout_hidden_ = value->zoneout_hidden();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_LSTM, PopulateLstmParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_LSTM, PopulateLstmParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 16
- 12
mindspore/lite/src/ops/populate/matmul_populate.cc View File

@@ -16,15 +16,10 @@
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "nnacl/matmul_parameter.h" #include "nnacl/matmul_parameter.h"
using mindspore::schema::PrimitiveType_MatMul; using mindspore::schema::PrimitiveType_MatMul;

namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateMatMulParameter(const void *prim) { OpParameter *PopulateMatMulParameter(const void *prim) {
auto *matmul_param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter)));
if (matmul_param == nullptr) {
MS_LOG(ERROR) << "malloc MatMulParameter failed.";
return nullptr;
}
memset(matmul_param, 0, sizeof(MatMulParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_MatMul(); auto value = primitive->value_as_MatMul();
@@ -32,13 +27,22 @@ OpParameter *PopulateMatMulParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
matmul_param->op_parameter_.type_ = primitive->value_type();
matmul_param->b_transpose_ = value->transpose_b();
matmul_param->a_transpose_ = value->transpose_a();
matmul_param->has_bias_ = false;
matmul_param->act_type_ = ActType_No;
return reinterpret_cast<OpParameter *>(matmul_param);

auto *param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc MatMulParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(MatMulParameter));

param->op_parameter_.type_ = primitive->value_type();
param->b_transpose_ = value->transpose_b();
param->a_transpose_ = value->transpose_a();
param->has_bias_ = false;
param->act_type_ = ActType_No;
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_MatMul, PopulateMatMulParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_MatMul, PopulateMatMulParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 10
- 8
mindspore/lite/src/ops/populate/merge_populate.cc View File

@@ -19,16 +19,18 @@ using mindspore::schema::PrimitiveType_Merge;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateMergeParameter(const void *prim) { OpParameter *PopulateMergeParameter(const void *prim) {
auto *merge_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (merge_parameter == nullptr) {
MS_LOG(ERROR) << "malloc Merge parameter failed.";
return nullptr;
}
memset(merge_parameter, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
merge_parameter->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(merge_parameter);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));

param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
REG_POPULATE(PrimitiveType_Merge, PopulateMergeParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Merge, PopulateMergeParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 13
- 13
mindspore/lite/src/ops/populate/mfcc_populate.cc View File

@@ -19,26 +19,26 @@ using mindspore::schema::PrimitiveType_Mfcc;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateMfccParameter(const void *prim) { OpParameter *PopulateMfccParameter(const void *prim) {
auto *arg_param = reinterpret_cast<MfccParameter *>(malloc(sizeof(MfccParameter)));
if (arg_param == nullptr) {
MS_LOG(ERROR) << "malloc MfccParameter failed.";
return nullptr;
}
memset(arg_param, 0, sizeof(MfccParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
arg_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_Mfcc();
auto value = primitive->value_as_Mfcc();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<MfccParameter *>(malloc(sizeof(MfccParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc MfccParameter failed.";
return nullptr; return nullptr;
} }
arg_param->dct_coeff_num_ = param->dct_coeff_num();
return reinterpret_cast<OpParameter *>(arg_param);
memset(param, 0, sizeof(MfccParameter));

param->op_parameter_.type_ = primitive->value_type();
param->dct_coeff_num_ = value->dct_coeff_num();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Mfcc, PopulateMfccParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Mfcc, PopulateMfccParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 4
- 4
mindspore/lite/src/ops/populate/mul_populate.cc View File

@@ -20,19 +20,19 @@ using mindspore::schema::PrimitiveType_MulFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateMulParameter(const void *prim) { OpParameter *PopulateMulParameter(const void *prim) {
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

ArithmeticParameter *param = PopulateArithmeticCommonPara(prim); ArithmeticParameter *param = PopulateArithmeticCommonPara(prim);
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr; return nullptr;
} }
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->op_parameter_.type_ = primitive->value_type(); param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_MulFusion, PopulateMulParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_MulFusion, PopulateMulParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 8
- 6
mindspore/lite/src/ops/populate/non_max_suppression_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_NonMaxSuppression;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateNonMaxSuppressionParameter(const void *prim) { OpParameter *PopulateNonMaxSuppressionParameter(const void *prim) {
auto *param = reinterpret_cast<NMSParameter *>(malloc(sizeof(NMSParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc param failed.";
return nullptr;
}
memset(param, 0, sizeof(NMSParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_NonMaxSuppression(); auto value = primitive->value_as_NonMaxSuppression();
@@ -33,6 +27,14 @@ OpParameter *PopulateNonMaxSuppressionParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }

auto *param = reinterpret_cast<NMSParameter *>(malloc(sizeof(NMSParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc NMSParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(NMSParameter));

param->op_parameter_.type_ = primitive->value_type(); param->op_parameter_.type_ = primitive->value_type();
param->center_point_box_ = value->center_point_box(); param->center_point_box_ = value->center_point_box();
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);


+ 12
- 10
mindspore/lite/src/ops/populate/one_hot_populate.cc View File

@@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_OneHot;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateOneHotParameter(const void *prim) { OpParameter *PopulateOneHotParameter(const void *prim) {
auto *one_hot_param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter)));
if (one_hot_param == nullptr) {
MS_LOG(ERROR) << "malloc OneHotParameter failed.";
return nullptr;
}
memset(one_hot_param, 0, sizeof(OneHotParameter));

auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_OneHot(); auto value = primitive->value_as_OneHot();
@@ -34,10 +27,19 @@ OpParameter *PopulateOneHotParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
one_hot_param->op_parameter_.type_ = primitive->value_type();
one_hot_param->axis_ = value->axis();
return reinterpret_cast<OpParameter *>(one_hot_param);

auto *param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc OneHotParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OneHotParameter));

param->op_parameter_.type_ = primitive->value_type();
param->axis_ = value->axis();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_OneHot, PopulateOneHotParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_OneHot, PopulateOneHotParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 5
- 3
mindspore/lite/src/ops/populate/oneslike_populate.cc View File

@@ -19,14 +19,16 @@ using mindspore::schema::PrimitiveType_OnesLike;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateOnesLikeParameter(const void *prim) { OpParameter *PopulateOnesLikeParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "malloc OnesLike Parameter failed.";
MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(OpParameter)); memset(param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

param->type_ = primitive->value_type(); param->type_ = primitive->value_type();
return param; return param;
} }


+ 9
- 6
mindspore/lite/src/ops/populate/p_relu_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_PReLUFusion;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulatePReLUParameter(const void *prim) { OpParameter *PopulatePReLUParameter(const void *prim) {
PReluParameter *param = reinterpret_cast<PReluParameter *>(malloc(sizeof(PReluParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc PReluParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(PReluParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_PReLUFusion(); auto value = primitive->value_as_PReLUFusion();
@@ -33,10 +27,19 @@ OpParameter *PopulatePReLUParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }

auto *param = reinterpret_cast<PReluParameter *>(malloc(sizeof(PReluParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc PReluParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(PReluParameter));

param->op_parameter_.type_ = primitive->value_type(); param->op_parameter_.type_ = primitive->value_type();
param->channelShared = value->channel_shared(); param->channelShared = value->channel_shared();
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_PReLUFusion, PopulatePReLUParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_PReLUFusion, PopulatePReLUParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 13
- 10
mindspore/lite/src/ops/populate/pad_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_PadFusion;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulatePadParameter(const void *prim) { OpParameter *PopulatePadParameter(const void *prim) {
auto *pad_param = reinterpret_cast<PadParameter *>(malloc(sizeof(PadParameter)));
if (pad_param == nullptr) {
MS_LOG(ERROR) << "malloc PadParameter failed.";
return nullptr;
}
memset(pad_param, 0, sizeof(PadParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_PadFusion(); auto value = primitive->value_as_PadFusion();
@@ -33,11 +27,20 @@ OpParameter *PopulatePadParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
pad_param->op_parameter_.type_ = primitive->value_type();
pad_param->pad_mode_ = value->padding_mode();
pad_param->constant_value_ = value->constant_value();
return reinterpret_cast<OpParameter *>(pad_param);

auto *param = reinterpret_cast<PadParameter *>(malloc(sizeof(PadParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc PadParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(PadParameter));

param->op_parameter_.type_ = primitive->value_type();
param->pad_mode_ = value->padding_mode();
param->constant_value_ = value->constant_value();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_PadFusion, PopulatePadParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_PadFusion, PopulatePadParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 11
- 10
mindspore/lite/src/ops/populate/partial_populate.cc View File

@@ -19,14 +19,7 @@ using mindspore::schema::PrimitiveType_PartialFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

OpParameter *PopulatePartialParameter(const void *prim) { OpParameter *PopulatePartialParameter(const void *prim) {
auto *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter)));
if (partial_parameter == nullptr) {
MS_LOG(ERROR) << "malloc partial parameter failed.";
return nullptr;
}
memset(partial_parameter, 0, sizeof(PartialParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_PartialFusion(); auto value = primitive->value_as_PartialFusion();
@@ -34,11 +27,19 @@ OpParameter *PopulatePartialParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
partial_parameter->op_parameter_.type_ = primitive->value_type();
partial_parameter->sub_graph_index_ = value->sub_graph_index();


return reinterpret_cast<OpParameter *>(partial_parameter);
auto *param = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc partial parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(PartialParameter));

param->op_parameter_.type_ = primitive->value_type();
param->sub_graph_index_ = value->sub_graph_index();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_PartialFusion, PopulatePartialParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_PartialFusion, PopulatePartialParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 83
- 78
mindspore/lite/src/ops/populate/pooling_populate.cc View File

@@ -20,155 +20,160 @@ using mindspore::schema::PrimitiveType_MaxPoolFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateAvgPoolParameter(const void *primitive) { OpParameter *PopulateAvgPoolParameter(const void *primitive) {
auto *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
if (pooling_param == nullptr) {
MS_LOG(ERROR) << "malloc PoolingParameter failed.";
return nullptr;
}
memset(pooling_param, 0, sizeof(PoolingParameter));
auto pooling_prim = static_cast<const schema::Primitive *>(primitive); auto pooling_prim = static_cast<const schema::Primitive *>(primitive);
MS_ASSERT(pooling_prim != nullptr); MS_ASSERT(pooling_prim != nullptr);
pooling_param->op_parameter_.type_ = pooling_prim->value_type();
auto pooling_primitive = pooling_prim->value_as_AvgPoolFusion();
if (pooling_primitive == nullptr) {
MS_LOG(ERROR) << "pooling_primitive is nullptr";
auto value = pooling_prim->value_as_AvgPoolFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc PoolingParameter failed.";
return nullptr; return nullptr;
} }
pooling_param->pool_mode_ = PoolMode_AvgPool;
pooling_param->global_ = pooling_primitive->global();
auto strides = pooling_primitive->strides();
memset(param, 0, sizeof(PoolingParameter));

param->op_parameter_.type_ = pooling_prim->value_type();
param->pool_mode_ = PoolMode_AvgPool;
param->global_ = value->global();
auto strides = value->strides();
if (strides == nullptr) { if (strides == nullptr) {
MS_LOG(ERROR) << "strides is nullptr"; MS_LOG(ERROR) << "strides is nullptr";
free(param);
return nullptr; return nullptr;
} }
pooling_param->stride_w_ = static_cast<int>(*(strides->begin() + 1));
pooling_param->stride_h_ = static_cast<int>(*(strides->begin()));
auto pad = pooling_primitive->pad();
param->stride_w_ = static_cast<int>(*(strides->begin() + 1));
param->stride_h_ = static_cast<int>(*(strides->begin()));
auto pad = value->pad();
if (pad != nullptr) { if (pad != nullptr) {
pooling_param->pad_u_ = static_cast<int>(*(pad->begin()));
pooling_param->pad_d_ = static_cast<int>(*(pad->begin() + 1));
pooling_param->pad_l_ = static_cast<int>(*(pad->begin() + 2));
pooling_param->pad_r_ = static_cast<int>(*(pad->begin() + 3));
param->pad_u_ = static_cast<int>(*(pad->begin()));
param->pad_d_ = static_cast<int>(*(pad->begin() + 1));
param->pad_l_ = static_cast<int>(*(pad->begin() + 2));
param->pad_r_ = static_cast<int>(*(pad->begin() + 3));
} }
if (!pooling_param->global_) {
auto kernel_size = pooling_primitive->kernel_size();
if (!param->global_) {
auto kernel_size = value->kernel_size();
if (kernel_size == nullptr) { if (kernel_size == nullptr) {
MS_LOG(ERROR) << "kernel_size is nullptr"; MS_LOG(ERROR) << "kernel_size is nullptr";
free(param);
return nullptr; return nullptr;
} }
pooling_param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1));
pooling_param->window_h_ = static_cast<int>(*(kernel_size->begin()));
param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1));
param->window_h_ = static_cast<int>(*(kernel_size->begin()));
} }


auto round_mode = pooling_primitive->round_mode();
auto round_mode = value->round_mode();
switch (round_mode) { switch (round_mode) {
case schema::RoundMode_FLOOR: case schema::RoundMode_FLOOR:
pooling_param->round_mode_ = RoundMode_Floor;
param->round_mode_ = RoundMode_Floor;
break; break;
case schema::RoundMode_CEIL: case schema::RoundMode_CEIL:
pooling_param->round_mode_ = RoundMode_Ceil;
param->round_mode_ = RoundMode_Ceil;
break; break;
default: default:
pooling_param->round_mode_ = RoundMode_No;
param->round_mode_ = RoundMode_No;
break; break;
} }


if (pooling_primitive->activation_type() == schema::ActivationType_RELU) {
pooling_param->act_type_ = ActType_Relu;
} else if (pooling_primitive->activation_type() == schema::ActivationType_RELU6) {
pooling_param->act_type_ = ActType_Relu6;
if (value->activation_type() == schema::ActivationType_RELU) {
param->act_type_ = ActType_Relu;
} else if (value->activation_type() == schema::ActivationType_RELU6) {
param->act_type_ = ActType_Relu6;
} else { } else {
pooling_param->act_type_ = ActType_No;
param->act_type_ = ActType_No;
} }


switch (pooling_primitive->pad_mode()) {
switch (value->pad_mode()) {
case schema::PadMode_SAME: case schema::PadMode_SAME:
pooling_param->pad_mode_ = Pad_same;
param->pad_mode_ = Pad_same;
break; break;
case schema::PadMode_VALID: case schema::PadMode_VALID:
pooling_param->pad_mode_ = Pad_valid;
param->pad_mode_ = Pad_valid;
break; break;
default: default:
pooling_param->pad_mode_ = Pad_pad;
param->pad_mode_ = Pad_pad;
break; break;
} }
return reinterpret_cast<OpParameter *>(pooling_param);
return reinterpret_cast<OpParameter *>(param);
} }


OpParameter *PopulateMaxPoolParameter(const void *primitive) { OpParameter *PopulateMaxPoolParameter(const void *primitive) {
auto *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
if (pooling_param == nullptr) {
MS_LOG(ERROR) << "malloc PoolingParameter failed.";
return nullptr;
}
memset(pooling_param, 0, sizeof(PoolingParameter));
auto pooling_prim = static_cast<const schema::Primitive *>(primitive); auto pooling_prim = static_cast<const schema::Primitive *>(primitive);
MS_ASSERT(pooling_prim != nullptr); MS_ASSERT(pooling_prim != nullptr);
pooling_param->op_parameter_.type_ = pooling_prim->value_type();
auto max_pool_prim = pooling_prim->value_as_MaxPoolFusion();
if (max_pool_prim == nullptr) {
MS_LOG(ERROR) << "max_pool_prim is nullptr";
auto value = pooling_prim->value_as_MaxPoolFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc PoolingParameter failed.";
return nullptr; return nullptr;
} }
pooling_param->pool_mode_ = PoolMode_MaxPool;
pooling_param->global_ = max_pool_prim->global();
if (!pooling_param->global_) {
auto kernel_size = max_pool_prim->kernel_size();
auto strides = max_pool_prim->strides();
memset(param, 0, sizeof(PoolingParameter));

param->op_parameter_.type_ = pooling_prim->value_type();
param->pool_mode_ = PoolMode_MaxPool;
param->global_ = value->global();
if (!param->global_) {
auto kernel_size = value->kernel_size();
auto strides = value->strides();
if (kernel_size == nullptr || strides == nullptr) { if (kernel_size == nullptr || strides == nullptr) {
MS_LOG(ERROR) << "kernel_size or strides is nullptr"; MS_LOG(ERROR) << "kernel_size or strides is nullptr";
free(param);
return nullptr; return nullptr;
} }
pooling_param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1));
pooling_param->window_h_ = static_cast<int>(*(kernel_size->begin()));
pooling_param->stride_w_ = static_cast<int>(*(strides->begin() + 1));
pooling_param->stride_h_ = static_cast<int>(*(strides->begin()));
auto pad = max_pool_prim->pad();
param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1));
param->window_h_ = static_cast<int>(*(kernel_size->begin()));
param->stride_w_ = static_cast<int>(*(strides->begin() + 1));
param->stride_h_ = static_cast<int>(*(strides->begin()));
auto pad = value->pad();
if (pad != nullptr) { if (pad != nullptr) {
pooling_param->pad_u_ = static_cast<int>(*(pad->begin()));
pooling_param->pad_d_ = static_cast<int>(*(pad->begin() + 1));
pooling_param->pad_l_ = static_cast<int>(*(pad->begin() + 2));
pooling_param->pad_r_ = static_cast<int>(*(pad->begin() + 3));
param->pad_u_ = static_cast<int>(*(pad->begin()));
param->pad_d_ = static_cast<int>(*(pad->begin() + 1));
param->pad_l_ = static_cast<int>(*(pad->begin() + 2));
param->pad_r_ = static_cast<int>(*(pad->begin() + 3));
} }
} }


auto round_mode = max_pool_prim->round_mode();
auto round_mode = value->round_mode();
switch (round_mode) { switch (round_mode) {
case schema::RoundMode_FLOOR: case schema::RoundMode_FLOOR:
pooling_param->round_mode_ = RoundMode_Floor;
param->round_mode_ = RoundMode_Floor;
break; break;
case schema::RoundMode_CEIL: case schema::RoundMode_CEIL:
pooling_param->round_mode_ = RoundMode_Ceil;
param->round_mode_ = RoundMode_Ceil;
break; break;
default: default:
pooling_param->round_mode_ = RoundMode_No;
param->round_mode_ = RoundMode_No;
break; break;
} }


if (max_pool_prim->activation_type() == schema::ActivationType_RELU) {
pooling_param->act_type_ = ActType_Relu;
} else if (max_pool_prim->activation_type() == schema::ActivationType_RELU6) {
pooling_param->act_type_ = ActType_Relu6;
if (value->activation_type() == schema::ActivationType_RELU) {
param->act_type_ = ActType_Relu;
} else if (value->activation_type() == schema::ActivationType_RELU6) {
param->act_type_ = ActType_Relu6;
} else { } else {
pooling_param->act_type_ = ActType_No;
param->act_type_ = ActType_No;
} }


switch (max_pool_prim->pad_mode()) {
switch (value->pad_mode()) {
case schema::PadMode_SAME: case schema::PadMode_SAME:
pooling_param->pad_mode_ = Pad_same;
param->pad_mode_ = Pad_same;
break; break;
case schema::PadMode_VALID: case schema::PadMode_VALID:
pooling_param->pad_mode_ = Pad_valid;
param->pad_mode_ = Pad_valid;
break; break;
default: default:
pooling_param->pad_mode_ = Pad_pad;
param->pad_mode_ = Pad_pad;
break; break;
} }
return reinterpret_cast<OpParameter *>(pooling_param);
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_AvgPoolFusion, PopulateAvgPoolParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_AvgPoolFusion, PopulateAvgPoolParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_MaxPoolFusion, PopulateMaxPoolParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_MaxPoolFusion, PopulateMaxPoolParameter, SCHEMA_CUR)


+ 15
- 15
mindspore/lite/src/ops/populate/power_populate.cc View File

@@ -19,27 +19,27 @@ using mindspore::schema::PrimitiveType_PowFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulatePowerParameter(const void *prim) { OpParameter *PopulatePowerParameter(const void *prim) {
auto *power_param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter)));
if (power_param == nullptr) {
MS_LOG(ERROR) << "malloc PowerParameter failed.";
return nullptr;
}
memset(power_param, 0, sizeof(PowerParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
power_param->op_parameter_.type_ = primitive->value_type();
auto power_prim = primitive->value_as_PowFusion();
if (power_prim == nullptr) {
MS_LOG(ERROR) << "power_prim is nullptr";
auto value = primitive->value_as_PowFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc PowerParameter failed.";
return nullptr; return nullptr;
} }
power_param->scale_ = power_prim->scale();
power_param->shift_ = power_prim->shift();
return reinterpret_cast<OpParameter *>(power_param);
memset(param, 0, sizeof(PowerParameter));

param->op_parameter_.type_ = primitive->value_type();
param->scale_ = value->scale();
param->shift_ = value->shift();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_PowFusion, PopulatePowerParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_PowFusion, PopulatePowerParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 33
- 27
mindspore/lite/src/ops/populate/prior_box_populate.cc View File

@@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_PriorBox;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulatePriorBoxParameter(const void *prim) { OpParameter *PopulatePriorBoxParameter(const void *prim) {
auto *prior_box_param = reinterpret_cast<PriorBoxParameter *>(malloc(sizeof(PriorBoxParameter)));
if (prior_box_param == nullptr) {
MS_LOG(ERROR) << "malloc PriorBoxParameter failed.";
return nullptr;
}
memset(prior_box_param, 0, sizeof(PriorBoxParameter));

auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_PriorBox(); auto value = primitive->value_as_PriorBox();
@@ -34,67 +27,80 @@ OpParameter *PopulatePriorBoxParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
prior_box_param->op_parameter_.type_ = primitive->value_type();

auto *param = reinterpret_cast<PriorBoxParameter *>(malloc(sizeof(PriorBoxParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc PriorBoxParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(PriorBoxParameter));

param->op_parameter_.type_ = primitive->value_type();
auto min_sizes = value->min_sizes(); auto min_sizes = value->min_sizes();
if (min_sizes == nullptr) { if (min_sizes == nullptr) {
MS_LOG(ERROR) << "min_sizes is nullptr"; MS_LOG(ERROR) << "min_sizes is nullptr";
free(param);
return nullptr; return nullptr;
} }
if (min_sizes->size() > MAX_SHAPE_SIZE) { if (min_sizes->size() > MAX_SHAPE_SIZE) {
MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << MAX_SHAPE_SIZE << ", got " << min_sizes->size(); MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << MAX_SHAPE_SIZE << ", got " << min_sizes->size();
free(prior_box_param);
free(param);
return nullptr; return nullptr;
} }
prior_box_param->min_sizes_size = min_sizes->size();
memcpy(prior_box_param->min_sizes, min_sizes->data(), min_sizes->size() * sizeof(int32_t));
param->min_sizes_size = min_sizes->size();
memcpy(param->min_sizes, min_sizes->data(), min_sizes->size() * sizeof(int32_t));


auto max_sizes = value->max_sizes(); auto max_sizes = value->max_sizes();
if (max_sizes == nullptr) { if (max_sizes == nullptr) {
MS_LOG(ERROR) << "max_sizes is nullptr"; MS_LOG(ERROR) << "max_sizes is nullptr";
free(param);
return nullptr; return nullptr;
} }
if (max_sizes->size() > MAX_SHAPE_SIZE) { if (max_sizes->size() > MAX_SHAPE_SIZE) {
MS_LOG(ERROR) << "PriorBox max_sizes size exceeds max num " << MAX_SHAPE_SIZE << ", got " << max_sizes->size(); MS_LOG(ERROR) << "PriorBox max_sizes size exceeds max num " << MAX_SHAPE_SIZE << ", got " << max_sizes->size();
free(prior_box_param);
free(param);
return nullptr; return nullptr;
} }
prior_box_param->max_sizes_size = max_sizes->size();
memcpy(prior_box_param->max_sizes, max_sizes->data(), max_sizes->size() * sizeof(int32_t));
param->max_sizes_size = max_sizes->size();
memcpy(param->max_sizes, max_sizes->data(), max_sizes->size() * sizeof(int32_t));


auto aspect_ratios = value->aspect_ratios(); auto aspect_ratios = value->aspect_ratios();
if (aspect_ratios == nullptr) { if (aspect_ratios == nullptr) {
MS_LOG(ERROR) << "aspect_ratios is nullptr"; MS_LOG(ERROR) << "aspect_ratios is nullptr";
free(param);
return nullptr; return nullptr;
} }
if (aspect_ratios->size() > MAX_SHAPE_SIZE) { if (aspect_ratios->size() > MAX_SHAPE_SIZE) {
MS_LOG(ERROR) << "PriorBox aspect_ratios size exceeds max num " << MAX_SHAPE_SIZE << ", got " MS_LOG(ERROR) << "PriorBox aspect_ratios size exceeds max num " << MAX_SHAPE_SIZE << ", got "
<< aspect_ratios->size(); << aspect_ratios->size();
free(prior_box_param);
free(param);
return nullptr; return nullptr;
} }
prior_box_param->aspect_ratios_size = aspect_ratios->size();
memcpy(prior_box_param->aspect_ratios, aspect_ratios->data(), aspect_ratios->size() * sizeof(float));
param->aspect_ratios_size = aspect_ratios->size();
memcpy(param->aspect_ratios, aspect_ratios->data(), aspect_ratios->size() * sizeof(float));


auto variances = value->variances(); auto variances = value->variances();
if (variances == nullptr) { if (variances == nullptr) {
MS_LOG(ERROR) << "variances is nullptr"; MS_LOG(ERROR) << "variances is nullptr";
free(param);
return nullptr; return nullptr;
} }
if (variances->size() != COMM_SHAPE_SIZE) { if (variances->size() != COMM_SHAPE_SIZE) {
MS_LOG(ERROR) << "PriorBox variances size should be " << COMM_SHAPE_SIZE << ", got " << variances->size(); MS_LOG(ERROR) << "PriorBox variances size should be " << COMM_SHAPE_SIZE << ", got " << variances->size();
free(prior_box_param);
free(param);
return nullptr; return nullptr;
} }
memcpy(prior_box_param->variances, variances->data(), COMM_SHAPE_SIZE * sizeof(float));
prior_box_param->flip = value->flip();
prior_box_param->clip = value->clip();
prior_box_param->offset = value->offset();
prior_box_param->image_size_h = value->image_size_h();
prior_box_param->image_size_w = value->image_size_w();
prior_box_param->step_h = value->step_h();
prior_box_param->step_w = value->step_w();
return reinterpret_cast<OpParameter *>(prior_box_param);
memcpy(param->variances, variances->data(), COMM_SHAPE_SIZE * sizeof(float));
param->flip = value->flip();
param->clip = value->clip();
param->offset = value->offset();
param->image_size_h = value->image_size_h();
param->image_size_w = value->image_size_w();
param->step_h = value->step_h();
param->step_w = value->step_w();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_PriorBox, PopulatePriorBoxParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_PriorBox, PopulatePriorBoxParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 13
- 11
mindspore/lite/src/ops/populate/quant_dtype_cast_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_QuantDTypeCast;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateQuantDTypeCastParameter(const void *prim) { OpParameter *PopulateQuantDTypeCastParameter(const void *prim) {
auto *parameter = reinterpret_cast<QuantDTypeCastParameter *>(malloc(sizeof(QuantDTypeCastParameter)));
if (parameter == nullptr) {
MS_LOG(ERROR) << "malloc QuantDTypeCastParameter failed.";
return nullptr;
}
memset(parameter, 0, sizeof(QuantDTypeCastParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_QuantDTypeCast(); auto value = primitive->value_as_QuantDTypeCast();
@@ -33,12 +27,20 @@ OpParameter *PopulateQuantDTypeCastParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->value_type();
parameter->srcT = value->src_t();
parameter->dstT = value->dst_t();
return reinterpret_cast<OpParameter *>(parameter);

auto *param = reinterpret_cast<QuantDTypeCastParameter *>(malloc(sizeof(QuantDTypeCastParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc QuantDTypeCastParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(QuantDTypeCastParameter));

param->op_parameter_.type_ = primitive->value_type();
param->srcT = value->src_t();
param->dstT = value->dst_t();
return reinterpret_cast<OpParameter *>(param);
} }
REG_POPULATE(PrimitiveType_QuantDTypeCast, PopulateQuantDTypeCastParameter, SCHEMA_CUR);


REG_POPULATE(PrimitiveType_QuantDTypeCast, PopulateQuantDTypeCastParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 15
- 12
mindspore/lite/src/ops/populate/random_standard_normal_populate.cc View File

@@ -22,25 +22,28 @@ namespace mindspore {
namespace lite { namespace lite {
namespace { namespace {
OpParameter *PopulateRandomStandardNormalParameter(const void *prim) { OpParameter *PopulateRandomStandardNormalParameter(const void *prim) {
auto *random_parameter = reinterpret_cast<RandomParam *>(malloc(sizeof(RandomParam)));
if (random_parameter == nullptr) {
MS_LOG(ERROR) << "malloc RandomStandardNormal parameter failed.";
return nullptr;
}
memset(random_parameter, 0, sizeof(RandomParam));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
random_parameter->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_RandomStandardNormal();
auto value = primitive->value_as_RandomStandardNormal();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<RandomParam *>(malloc(sizeof(RandomParam)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc RandomParam failed.";
return nullptr; return nullptr;
} }
random_parameter->seed_ = param->seed();
random_parameter->seed2_ = param->seed2();
return reinterpret_cast<OpParameter *>(random_parameter);
memset(param, 0, sizeof(RandomParam));

param->op_parameter_.type_ = primitive->value_type();
param->seed_ = value->seed();
param->seed2_ = value->seed2();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace } // namespace

REG_POPULATE(PrimitiveType_RandomStandardNormal, PopulateRandomStandardNormalParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_RandomStandardNormal, PopulateRandomStandardNormalParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 18
- 17
mindspore/lite/src/ops/populate/range_populate.cc View File

@@ -19,29 +19,30 @@ using mindspore::schema::PrimitiveType_Range;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateRangeParameter(const void *prim) { OpParameter *PopulateRangeParameter(const void *prim) {
auto *range_param = reinterpret_cast<RangeParameter *>(malloc(sizeof(RangeParameter)));
if (range_param == nullptr) {
MS_LOG(ERROR) << "malloc RangeParameter failed.";
return nullptr;
}
memset(range_param, 0, sizeof(RangeParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
range_param->op_parameter_.type_ = primitive->value_type();
auto range_prim = primitive->value_as_Range();
if (range_prim == nullptr) {
MS_LOG(ERROR) << "range_prim is nullptr";
auto value = primitive->value_as_Range();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
range_param->start_ = range_prim->start();
range_param->limit_ = range_prim->limit();
range_param->delta_ = range_prim->delta();
range_param->dType_ = range_prim->d_type();
return reinterpret_cast<OpParameter *>(range_param);

auto *param = reinterpret_cast<RangeParameter *>(malloc(sizeof(RangeParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc RangeParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(RangeParameter));

param->op_parameter_.type_ = primitive->value_type();
param->start_ = value->start();
param->limit_ = value->limit();
param->delta_ = value->delta();
param->dType_ = value->d_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_Range, PopulateRangeParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Range, PopulateRangeParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 9
- 9
mindspore/lite/src/ops/populate/rank_populate.cc View File

@@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_Rank;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateRankParameter(const void *prim) { OpParameter *PopulateRankParameter(const void *prim) {
auto *rank_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (rank_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc RankParameter failed."; MS_LOG(ERROR) << "malloc RankParameter failed.";
return nullptr; return nullptr;
} }
memset(rank_param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
rank_param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(rank_param);
memset(param, 0, sizeof(OpParameter));

param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Rank, PopulateRankParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Rank, PopulateRankParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 15
- 14
mindspore/lite/src/ops/populate/reduce_populate.cc View File

@@ -13,19 +13,13 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "nnacl/reduce_parameter.h" #include "nnacl/reduce_parameter.h"
using mindspore::schema::PrimitiveType_ReduceFusion; using mindspore::schema::PrimitiveType_ReduceFusion;

namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateReduceParameter(const void *prim) { OpParameter *PopulateReduceParameter(const void *prim) {
auto *reduce_param = reinterpret_cast<ReduceParameter *>(malloc(sizeof(ReduceParameter)));
if (reduce_param == nullptr) {
MS_LOG(ERROR) << "malloc ReduceParameter failed.";
return nullptr;
}
memset(reduce_param, 0, sizeof(ReduceParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_ReduceFusion(); auto value = primitive->value_as_ReduceFusion();
@@ -33,15 +27,22 @@ OpParameter *PopulateReduceParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
reduce_param->op_parameter_.type_ = primitive->value_type();
reduce_param->keep_dims_ = value->keep_dims();
reduce_param->reduce_to_end_ = value->reduce_to_end();
reduce_param->coeff = value->coeff();
reduce_param->mode_ = static_cast<int>(value->mode());
return reinterpret_cast<OpParameter *>(reduce_param);

auto *param = reinterpret_cast<ReduceParameter *>(malloc(sizeof(ReduceParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ReduceParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ReduceParameter));

param->op_parameter_.type_ = primitive->value_type();
param->keep_dims_ = value->keep_dims();
param->reduce_to_end_ = value->reduce_to_end();
param->coeff = value->coeff();
param->mode_ = static_cast<int>(value->mode());
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_ReduceFusion, PopulateReduceParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ReduceFusion, PopulateReduceParameter, SCHEMA_CUR)

} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 9
- 9
mindspore/lite/src/ops/populate/reshape_populate.cc View File

@@ -19,20 +19,20 @@ using mindspore::schema::PrimitiveType_Reshape;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateReshapeParameter(const void *prim) { OpParameter *PopulateReshapeParameter(const void *prim) {
auto *reshape_param = reinterpret_cast<ReshapeParameter *>(malloc(sizeof(ReshapeParameter)));
if (reshape_param == nullptr) {
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<ReshapeParameter *>(malloc(sizeof(ReshapeParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ReshapeParameter failed."; MS_LOG(ERROR) << "malloc ReshapeParameter failed.";
return nullptr; return nullptr;
} }
memset(reshape_param, 0, sizeof(ReshapeParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
reshape_param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(reshape_param);
memset(param, 0, sizeof(ReshapeParameter));

param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Reshape, PopulateReshapeParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Reshape, PopulateReshapeParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 15
- 15
mindspore/lite/src/ops/populate/resize_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_Resize;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateResizeParameter(const void *prim) { OpParameter *PopulateResizeParameter(const void *prim) {
auto *resize_param = reinterpret_cast<ResizeParameter *>(malloc(sizeof(ResizeParameter)));
if (resize_param == nullptr) {
MS_LOG(ERROR) << "malloc ResizeParameter failed.";
return nullptr;
}
memset(resize_param, 0, sizeof(ResizeParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_Resize(); auto value = primitive->value_as_Resize();
@@ -33,18 +27,24 @@ OpParameter *PopulateResizeParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
resize_param->op_parameter_.type_ = primitive->value_type();


resize_param->method_ = static_cast<int>(value->method());
resize_param->new_height_ = value->new_height();
resize_param->new_width_ = value->new_width();
resize_param->coordinate_transform_mode_ = value->coordinate_transform_mode();
resize_param->preserve_aspect_ratio_ = value->preserve_aspect_ratio();
resize_param->cubic_coeff_ = value->cubic_coeff();
return reinterpret_cast<OpParameter *>(resize_param);
auto *param = reinterpret_cast<ResizeParameter *>(malloc(sizeof(ResizeParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ResizeParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ResizeParameter));

param->op_parameter_.type_ = primitive->value_type();
param->method_ = static_cast<int>(value->method());
param->new_height_ = value->new_height();
param->new_width_ = value->new_width();
param->coordinate_transform_mode_ = value->coordinate_transform_mode();
param->preserve_aspect_ratio_ = value->preserve_aspect_ratio();
param->cubic_coeff_ = value->cubic_coeff();
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_Resize, PopulateResizeParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Resize, PopulateResizeParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite

} // namespace mindspore } // namespace mindspore

+ 13
- 11
mindspore/lite/src/ops/populate/reverse_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_ReverseV2;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateReverseParameter(const void *prim) { OpParameter *PopulateReverseParameter(const void *prim) {
auto *reverse_param = reinterpret_cast<ReverseParameter *>(malloc(sizeof(ReverseParameter)));
if (reverse_param == nullptr) {
MS_LOG(ERROR) << "malloc ReverseParameter failed.";
return nullptr;
}
memset(reverse_param, 0, sizeof(ReverseParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_ReverseV2(); auto value = primitive->value_as_ReverseV2();
@@ -33,19 +27,27 @@ OpParameter *PopulateReverseParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
reverse_param->op_parameter_.type_ = primitive->value_type();


auto *param = reinterpret_cast<ReverseParameter *>(malloc(sizeof(ReverseParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ReverseParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ReverseParameter));

param->op_parameter_.type_ = primitive->value_type();
auto flatAxis = value->axis(); auto flatAxis = value->axis();
if (flatAxis == nullptr) { if (flatAxis == nullptr) {
MS_LOG(ERROR) << "flatAxis is nullptr"; MS_LOG(ERROR) << "flatAxis is nullptr";
free(param);
return nullptr; return nullptr;
} }
reverse_param->num_axis_ = flatAxis->size();
param->num_axis_ = flatAxis->size();
int i = 0; int i = 0;
for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) {
reverse_param->axis_[i++] = *iter;
for (auto flatAxi : *flatAxis) {
param->axis_[i++] = static_cast<int>(flatAxi);
} }
return reinterpret_cast<OpParameter *>(reverse_param);
return reinterpret_cast<OpParameter *>(param);
} }


REG_POPULATE(PrimitiveType_ReverseV2, PopulateReverseParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ReverseV2, PopulateReverseParameter, SCHEMA_CUR)


+ 14
- 15
mindspore/lite/src/ops/populate/reverse_sequence_populate.cc View File

@@ -19,29 +19,28 @@ using mindspore::schema::PrimitiveType_ReverseSequence;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateReverseSequenceParameter(const void *prim) { OpParameter *PopulateReverseSequenceParameter(const void *prim) {
auto *reverse_sequence_param = reinterpret_cast<ReverseSequenceParameter *>(malloc(sizeof(ReverseSequenceParameter)));
if (reverse_sequence_param == nullptr) {
MS_LOG(ERROR) << "malloc ReverseSequenceParameter failed.";
return nullptr;
}
memset(reverse_sequence_param, 0, sizeof(ReverseSequenceParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto param = primitive->value_as_ReverseSequence();
auto value = primitive->value_as_ReverseSequence();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<ReverseSequenceParameter *>(malloc(sizeof(ReverseSequenceParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc ReverseSequenceParameter failed.";
return nullptr; return nullptr;
} }
reverse_sequence_param->op_parameter_.type_ = primitive->value_type();
reverse_sequence_param->seq_axis_ = static_cast<int>(param->seq_dim());
reverse_sequence_param->batch_axis_ = static_cast<int>(param->batch_dim());
return reinterpret_cast<OpParameter *>(reverse_sequence_param);
memset(param, 0, sizeof(ReverseSequenceParameter));

param->op_parameter_.type_ = primitive->value_type();
param->seq_axis_ = static_cast<int>(value->seq_dim());
param->batch_axis_ = static_cast<int>(value->batch_dim());
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_ReverseSequence, PopulateReverseSequenceParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_ReverseSequence, PopulateReverseSequenceParameter, SCHEMA_CUR);

} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 15
- 16
mindspore/lite/src/ops/populate/roi_pooling_populate.cc View File

@@ -19,29 +19,28 @@ using mindspore::schema::PrimitiveType_ROIPooling;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateROIPoolingParameter(const void *prim) { OpParameter *PopulateROIPoolingParameter(const void *prim) {
auto *roi_param = reinterpret_cast<ROIPoolingParameter *>(malloc(sizeof(ROIPoolingParameter)));
if (roi_param == nullptr) {
MS_LOG(ERROR) << "malloc ROIPoolingParameter failed.";
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_ROIPooling();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }


memset(roi_param, 0, sizeof(ROIPoolingParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
roi_param->op_parameter_.type_ = primitive->value_type();
auto roi_prim = primitive->value_as_ROIPooling();
if (roi_prim == nullptr) {
MS_LOG(ERROR) << "roi_prim is nullptr";
auto *param = reinterpret_cast<ROIPoolingParameter *>(malloc(sizeof(ROIPoolingParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ROIPoolingParameter failed.";
return nullptr; return nullptr;
} }
roi_param->pooledH_ = roi_prim->pooled_h();
roi_param->pooledW_ = roi_prim->pooled_w();
roi_param->scale_ = roi_prim->scale();
return reinterpret_cast<OpParameter *>(roi_param);
memset(param, 0, sizeof(ROIPoolingParameter));

param->op_parameter_.type_ = primitive->value_type();
param->pooledH_ = value->pooled_h();
param->pooledW_ = value->pooled_w();
param->scale_ = value->scale();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_ROIPooling, PopulateROIPoolingParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ROIPooling, PopulateROIPoolingParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 12
- 12
mindspore/lite/src/ops/populate/scale_populate.cc View File

@@ -19,14 +19,7 @@ using mindspore::schema::PrimitiveType_ScaleFusion;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateScaleParameter(const void *prim) { OpParameter *PopulateScaleParameter(const void *prim) {
auto *scale_param = reinterpret_cast<ScaleParameter *>(malloc(sizeof(ScaleParameter)));
if (scale_param == nullptr) {
MS_LOG(ERROR) << "malloc ScaleParameter failed.";
return nullptr;
}
memset(scale_param, 0, sizeof(ScaleParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_ScaleFusion(); auto value = primitive->value_as_ScaleFusion();
@@ -34,12 +27,19 @@ OpParameter *PopulateScaleParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
scale_param->op_parameter_.type_ = primitive->value_type();
scale_param->axis_ = value->axis();
scale_param->activation_type_ = value->activation_type();
return reinterpret_cast<OpParameter *>(scale_param);

auto *param = reinterpret_cast<ScaleParameter *>(malloc(sizeof(ScaleParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ScaleParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ScaleParameter));

param->op_parameter_.type_ = primitive->value_type();
param->axis_ = value->axis();
param->activation_type_ = value->activation_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_ScaleFusion, PopulateScaleParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ScaleFusion, PopulateScaleParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 10
- 9
mindspore/lite/src/ops/populate/scatter_nd_populate.cc View File

@@ -18,20 +18,21 @@ using mindspore::schema::PrimitiveType_ScatterNd;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateScatterNDParameter(const void *prim) { OpParameter *PopulateScatterNDParameter(const void *prim) {
auto *scatter_nd_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (scatter_nd_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ScatterNDParameter failed."; MS_LOG(ERROR) << "malloc ScatterNDParameter failed.";
return nullptr; return nullptr;
} }
memset(scatter_nd_param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
scatter_nd_param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(scatter_nd_param);
memset(param, 0, sizeof(OpParameter));

param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_ScatterNd, PopulateScatterNDParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_ScatterNd, PopulateScatterNDParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 9
- 9
mindspore/lite/src/ops/populate/shape_populate.cc View File

@@ -20,20 +20,20 @@ using mindspore::schema::PrimitiveType_Shape;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateShapeParameter(const void *prim) { OpParameter *PopulateShapeParameter(const void *prim) {
auto *shape_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (shape_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ShapeParameter failed."; MS_LOG(ERROR) << "malloc ShapeParameter failed.";
return nullptr; return nullptr;
} }
memset(shape_param, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
shape_param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(shape_param);
memset(param, 0, sizeof(OpParameter));

param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Shape, PopulateShapeParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Shape, PopulateShapeParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 14
- 11
mindspore/lite/src/ops/populate/skip_gram_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_SkipGram;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateSkipGramParameter(const void *prim) { OpParameter *PopulateSkipGramParameter(const void *prim) {
auto *skipGramParameter = reinterpret_cast<SkipGramParameter *>(malloc(sizeof(SkipGramParameter)));
if (skipGramParameter == nullptr) {
MS_LOG(ERROR) << "malloc SkipGramParameter failed.";
return nullptr;
}
memset(skipGramParameter, 0, sizeof(SkipGramParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_SkipGram(); auto value = primitive->value_as_SkipGram();
@@ -33,12 +27,21 @@ OpParameter *PopulateSkipGramParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
skipGramParameter->op_parameter_.type_ = primitive->value_type();
skipGramParameter->ngram_size = value->ngram_size();
skipGramParameter->max_skip_size = value->max_skip_size();
skipGramParameter->include_all_ngrams = value->include_all_grams();
return reinterpret_cast<OpParameter *>(skipGramParameter);

auto *param = reinterpret_cast<SkipGramParameter *>(malloc(sizeof(SkipGramParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SkipGramParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(SkipGramParameter));

param->op_parameter_.type_ = primitive->value_type();
param->ngram_size = value->ngram_size();
param->max_skip_size = value->max_skip_size();
param->include_all_ngrams = value->include_all_grams();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_SkipGram, PopulateSkipGramParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_SkipGram, PopulateSkipGramParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 13
- 9
mindspore/lite/src/ops/populate/slice_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_SliceFusion;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateSliceParameter(const void *prim) { OpParameter *PopulateSliceParameter(const void *prim) {
auto *slice_param = reinterpret_cast<SliceParameter *>(malloc(sizeof(SliceParameter)));
if (slice_param == nullptr) {
MS_LOG(ERROR) << "malloc SliceParameter failed.";
return nullptr;
}
memset(slice_param, 0, sizeof(SliceParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_SliceFusion(); auto value = primitive->value_as_SliceFusion();
@@ -33,17 +27,27 @@ OpParameter *PopulateSliceParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
slice_param->op_parameter_.type_ = primitive->value_type();

auto *param = reinterpret_cast<SliceParameter *>(malloc(sizeof(SliceParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SliceParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(SliceParameter));

param->op_parameter_.type_ = primitive->value_type();
auto axes = value->axes(); auto axes = value->axes();
if (axes == nullptr) { if (axes == nullptr) {
MS_LOG(ERROR) << "axes is nullptr"; MS_LOG(ERROR) << "axes is nullptr";
free(param);
return nullptr; return nullptr;
} }
for (size_t i = 0; i < axes->size(); ++i) { for (size_t i = 0; i < axes->size(); ++i) {
slice_param->axis_[i] = axes->Get(i);
param->axis_[i] = axes->Get(i);
} }
return reinterpret_cast<OpParameter *>(slice_param);
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_SliceFusion, PopulateSliceParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_SliceFusion, PopulateSliceParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 17
- 16
mindspore/lite/src/ops/populate/softmax_populate.cc View File

@@ -19,36 +19,37 @@ using mindspore::schema::PrimitiveType_Softmax;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateSoftmaxParameter(const void *prim) { OpParameter *PopulateSoftmaxParameter(const void *prim) {
auto *softmax_param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter)));
if (softmax_param == nullptr) {
MS_LOG(ERROR) << "malloc SoftmaxParameter failed.";
return nullptr;
}
memset(softmax_param, 0, sizeof(SoftmaxParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
softmax_param->op_parameter_.type_ = primitive->value_type();
auto prim_softmax = primitive->value_as_Softmax();
if (prim_softmax == nullptr) {
MS_LOG(ERROR) << "prim_softmax is nullptr";
auto value = primitive->value_as_Softmax();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SoftmaxParameter failed.";
return nullptr; return nullptr;
} }
auto axis = prim_softmax->axis();
memset(param, 0, sizeof(SoftmaxParameter));

param->op_parameter_.type_ = primitive->value_type();
auto axis = value->axis();
if (axis == nullptr) { if (axis == nullptr) {
MS_LOG(ERROR) << "axis is nullptr"; MS_LOG(ERROR) << "axis is nullptr";
free(param);
return nullptr; return nullptr;
} }
if (axis->size() != 1) { if (axis->size() != 1) {
MS_LOG(ERROR) << "axis number invalid!number: " << axis->size(); MS_LOG(ERROR) << "axis number invalid!number: " << axis->size();
free(softmax_param);
free(param);
return nullptr; return nullptr;
} }
softmax_param->axis_ = axis->data()[0];
return reinterpret_cast<OpParameter *>(softmax_param);
param->axis_ = axis->data()[0];
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_Softmax, PopulateSoftmaxParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Softmax, PopulateSoftmaxParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite


+ 27
- 24
mindspore/lite/src/ops/populate/space_to_batch_nd_populate.cc View File

@@ -19,42 +19,45 @@ using mindspore::schema::PrimitiveType_SpaceToBatchND;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) { OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) {
auto *space_batch_param_nd = reinterpret_cast<SpaceToBatchParameter *>(malloc(sizeof(SpaceToBatchParameter)));
if (space_batch_param_nd == nullptr) {
MS_LOG(ERROR) << "malloc SpaceToBatchParameter failed.";
return nullptr;
}
memset(space_batch_param_nd, 0, sizeof(SpaceToBatchParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
space_batch_param_nd->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_SpaceToBatchND();
auto value = primitive->value_as_SpaceToBatchND();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}

auto *param = reinterpret_cast<SpaceToBatchParameter *>(malloc(sizeof(SpaceToBatchParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
MS_LOG(ERROR) << "malloc SpaceToBatchParameter failed.";
return nullptr; return nullptr;
} }
auto block_shape = param->block_shape();
memset(param, 0, sizeof(SpaceToBatchParameter));

param->op_parameter_.type_ = primitive->value_type();
auto block_shape = value->block_shape();
if (block_shape == nullptr) { if (block_shape == nullptr) {
return reinterpret_cast<OpParameter *>(space_batch_param_nd);
return reinterpret_cast<OpParameter *>(param);
} }
auto block_shapes = std::vector<int64_t>(block_shape->begin(), block_shape->end()); auto block_shapes = std::vector<int64_t>(block_shape->begin(), block_shape->end());
if (block_shapes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) { if (block_shapes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) {
MS_LOG(ERROR) << "The value of block_shapes.size() is too big"; MS_LOG(ERROR) << "The value of block_shapes.size() is too big";
free(space_batch_param_nd);
free(param);
return nullptr; return nullptr;
} }
space_batch_param_nd->m_ = block_shapes.size();
param->m_ = block_shapes.size();


auto param_paddings = param->paddings();
auto param_paddings = value->paddings();
if (param_paddings == nullptr) { if (param_paddings == nullptr) {
MS_LOG(ERROR) << "param_paddings is nullptr"; MS_LOG(ERROR) << "param_paddings is nullptr";
free(param);
return nullptr; return nullptr;
} }
auto fb_paddings = param_paddings->data(); auto fb_paddings = param_paddings->data();
if (fb_paddings == nullptr) { if (fb_paddings == nullptr) {
MS_LOG(ERROR) << "fb_paddings is nullptr"; MS_LOG(ERROR) << "fb_paddings is nullptr";
free(param);
return nullptr; return nullptr;
} }
if (fb_paddings->size() == 0 || if (fb_paddings->size() == 0 ||
@@ -62,14 +65,15 @@ OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) {
static_cast<uint64_t>(fb_paddings->size() * (*(fb_paddings->begin()))->data()->size()) > static_cast<uint64_t>(fb_paddings->size() * (*(fb_paddings->begin()))->data()->size()) >
std::numeric_limits<size_t>::max() / sizeof(int64_t))) { std::numeric_limits<size_t>::max() / sizeof(int64_t))) {
MS_LOG(ERROR) << "The value of paddings.size() is zero or too big"; MS_LOG(ERROR) << "The value of paddings.size() is zero or too big";
free(space_batch_param_nd);
free(param);
return nullptr; return nullptr;
} }
std::vector<int64_t> paddings; std::vector<int64_t> paddings;
for (auto iter = fb_paddings->begin(); iter != fb_paddings->end(); ++iter) {
auto paddings_data = (*iter)->data();
for (auto fb_padding : *fb_paddings) {
auto paddings_data = fb_padding->data();
if (paddings_data == nullptr) { if (paddings_data == nullptr) {
MS_LOG(ERROR) << "paddings_data is nullptr"; MS_LOG(ERROR) << "paddings_data is nullptr";
free(param);
return nullptr; return nullptr;
} }
auto paddings_vec = std::vector<int64_t>(paddings_data->begin(), paddings_data->end()); auto paddings_vec = std::vector<int64_t>(paddings_data->begin(), paddings_data->end());
@@ -77,17 +81,16 @@ OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) {
} }


for (size_t i = 0; i < block_shapes.size(); ++i) { for (size_t i = 0; i < block_shapes.size(); ++i) {
space_batch_param_nd->block_sizes_[i] = static_cast<int>(block_shapes[i]);
param->block_sizes_[i] = static_cast<int>(block_shapes[i]);
} }

space_batch_param_nd->m_ = block_shapes.size();
param->m_ = block_shapes.size();


for (size_t i = 0; i < paddings.size(); ++i) { for (size_t i = 0; i < paddings.size(); ++i) {
space_batch_param_nd->paddings_[i] = static_cast<int>(paddings[i]);
param->paddings_[i] = static_cast<int>(paddings[i]);
} }
return reinterpret_cast<OpParameter *>(space_batch_param_nd);
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_SpaceToBatchND, PopulateSpaceToBatchNDParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_SpaceToBatchND, PopulateSpaceToBatchNDParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 26
- 21
mindspore/lite/src/ops/populate/space_to_batch_populate.cc View File

@@ -19,43 +19,47 @@ using mindspore::schema::PrimitiveType_SpaceToBatch;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateSpaceToBatchParameter(const void *prim) { OpParameter *PopulateSpaceToBatchParameter(const void *prim) {
auto *space_batch_param = reinterpret_cast<SpaceToBatchParameter *>(malloc(sizeof(SpaceToBatchParameter)));
if (space_batch_param == nullptr) {
MS_LOG(ERROR) << "malloc SpaceToBatchParameter failed.";
return nullptr;
}
memset(space_batch_param, 0, sizeof(SpaceToBatchParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
space_batch_param->op_parameter_.type_ = primitive->value_type();
auto param = primitive->value_as_SpaceToBatch();
if (param == nullptr) {
auto value = primitive->value_as_SpaceToBatch();
if (value == nullptr) {
MS_LOG(ERROR) << "param is nullptr"; MS_LOG(ERROR) << "param is nullptr";
return nullptr; return nullptr;
} }
auto block_size = param->block_size();

auto *param = reinterpret_cast<SpaceToBatchParameter *>(malloc(sizeof(SpaceToBatchParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SpaceToBatchParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(SpaceToBatchParameter));

param->op_parameter_.type_ = primitive->value_type();
auto block_size = value->block_size();
if (block_size == nullptr) { if (block_size == nullptr) {
MS_LOG(ERROR) << "block_size is nullptr"; MS_LOG(ERROR) << "block_size is nullptr";
free(param);
return nullptr; return nullptr;
} }
auto block_sizes = std::vector<int64_t>(block_size->begin(), block_size->end()); auto block_sizes = std::vector<int64_t>(block_size->begin(), block_size->end());
if (block_sizes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) { if (block_sizes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) {
MS_LOG(ERROR) << "The value of block_sizes.size() is too big"; MS_LOG(ERROR) << "The value of block_sizes.size() is too big";
free(space_batch_param);
free(param);
return nullptr; return nullptr;
} }
space_batch_param->m_ = block_sizes.size();
param->m_ = block_sizes.size();


auto param_paddings = param->paddings();
auto param_paddings = value->paddings();
if (param_paddings == nullptr) { if (param_paddings == nullptr) {
MS_LOG(ERROR) << "param_paddings is nullptr"; MS_LOG(ERROR) << "param_paddings is nullptr";
free(param);
return nullptr; return nullptr;
} }
auto fb_paddings = param_paddings->data(); auto fb_paddings = param_paddings->data();
if (fb_paddings == nullptr) { if (fb_paddings == nullptr) {
MS_LOG(ERROR) << "fb_paddings is nullptr"; MS_LOG(ERROR) << "fb_paddings is nullptr";
free(param);
return nullptr; return nullptr;
} }
if (fb_paddings->size() == 0 || if (fb_paddings->size() == 0 ||
@@ -63,14 +67,15 @@ OpParameter *PopulateSpaceToBatchParameter(const void *prim) {
static_cast<uint64_t>(fb_paddings->size() * (*(fb_paddings->begin()))->data()->size()) > static_cast<uint64_t>(fb_paddings->size() * (*(fb_paddings->begin()))->data()->size()) >
std::numeric_limits<size_t>::max() / sizeof(int64_t))) { std::numeric_limits<size_t>::max() / sizeof(int64_t))) {
MS_LOG(ERROR) << "The value of paddings.size() is zero or too big"; MS_LOG(ERROR) << "The value of paddings.size() is zero or too big";
free(space_batch_param);
free(param);
return nullptr; return nullptr;
} }
std::vector<int64_t> paddings; std::vector<int64_t> paddings;
for (auto iter = fb_paddings->begin(); iter != fb_paddings->end(); ++iter) {
auto paddings_data = (*iter)->data();
for (auto fb_padding : *fb_paddings) {
auto paddings_data = fb_padding->data();
if (paddings_data == nullptr) { if (paddings_data == nullptr) {
MS_LOG(ERROR) << "paddings_data is nullptr"; MS_LOG(ERROR) << "paddings_data is nullptr";
free(param);
return nullptr; return nullptr;
} }
auto paddings_vec = std::vector<int64_t>(paddings_data->begin(), paddings_data->end()); auto paddings_vec = std::vector<int64_t>(paddings_data->begin(), paddings_data->end());
@@ -78,15 +83,15 @@ OpParameter *PopulateSpaceToBatchParameter(const void *prim) {
} }


for (size_t i = 0; i < block_sizes.size(); ++i) { for (size_t i = 0; i < block_sizes.size(); ++i) {
space_batch_param->block_sizes_[i] = static_cast<int>(block_sizes[i]);
param->block_sizes_[i] = static_cast<int>(block_sizes[i]);
} }


for (size_t i = 0; i < paddings.size(); ++i) { for (size_t i = 0; i < paddings.size(); ++i) {
space_batch_param->paddings_[i] = static_cast<int>(paddings[i]);
param->paddings_[i] = static_cast<int>(paddings[i]);
} }
return reinterpret_cast<OpParameter *>(space_batch_param);
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_SpaceToBatch, PopulateSpaceToBatchParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_SpaceToBatch, PopulateSpaceToBatchParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 13
- 10
mindspore/lite/src/ops/populate/space_to_depth_populate.cc View File

@@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_SpaceToDepth;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateSpaceToDepthParameter(const void *prim) { OpParameter *PopulateSpaceToDepthParameter(const void *prim) {
auto *space_depth_param = reinterpret_cast<SpaceToDepthParameter *>(malloc(sizeof(SpaceToDepthParameter)));
if (space_depth_param == nullptr) {
MS_LOG(ERROR) << "malloc SpaceToDepthParameter failed.";
return nullptr;
}
memset(space_depth_param, 0, sizeof(SpaceToDepthParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_SpaceToDepth(); auto value = primitive->value_as_SpaceToDepth();
@@ -33,15 +27,24 @@ OpParameter *PopulateSpaceToDepthParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
space_depth_param->op_parameter_.type_ = primitive->value_type();
space_depth_param->block_size_ = value->block_size();

auto *param = reinterpret_cast<SpaceToDepthParameter *>(malloc(sizeof(SpaceToDepthParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SpaceToDepthParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(SpaceToDepthParameter));

param->op_parameter_.type_ = primitive->value_type();
param->block_size_ = value->block_size();
if (value->format() != schema::Format::Format_NHWC) { if (value->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "Currently only NHWC format is supported."; MS_LOG(ERROR) << "Currently only NHWC format is supported.";
free(space_depth_param);
free(param);
return nullptr; return nullptr;
} }
return reinterpret_cast<OpParameter *>(space_depth_param);
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_SpaceToDepth, PopulateSpaceToDepthParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_SpaceToDepth, PopulateSpaceToDepthParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 10
- 8
mindspore/lite/src/ops/populate/sparse_softmax_cross_entropy_with_logits.cc View File

@@ -20,18 +20,20 @@ using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) { OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) {
auto *softmax_cross_entropy_param_ =
reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
if (softmax_cross_entropy_param_ == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed."; MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
return nullptr; return nullptr;
} }
memset(softmax_cross_entropy_param_, 0, sizeof(SoftmaxCrossEntropyParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
softmax_cross_entropy_param_->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(softmax_cross_entropy_param_);
memset(param, 0, sizeof(SoftmaxCrossEntropyParameter));

param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_SparseSoftmaxCrossEntropyWithLogits, PopulateSparseSoftmaxCrossEntropyWithLogitsParameter, REG_POPULATE(PrimitiveType_SparseSoftmaxCrossEntropyWithLogits, PopulateSparseSoftmaxCrossEntropyWithLogitsParameter,
SCHEMA_CUR); SCHEMA_CUR);
} // namespace lite } // namespace lite


+ 9
- 9
mindspore/lite/src/ops/populate/sparse_to_dense_populate.cc View File

@@ -19,20 +19,20 @@ using mindspore::schema::PrimitiveType_SparseToDense;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateSparseToDenseParameter(const void *prim) { OpParameter *PopulateSparseToDenseParameter(const void *prim) {
auto *sparse_to_dense_param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter)));
if (sparse_to_dense_param == nullptr) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);

auto *param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SparseToDenseParameter failed."; MS_LOG(ERROR) << "malloc SparseToDenseParameter failed.";
return nullptr; return nullptr;
} }
memset(sparse_to_dense_param, 0, sizeof(SparseToDenseParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
sparse_to_dense_param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(sparse_to_dense_param);
memset(param, 0, sizeof(SparseToDenseParameter));

param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace


REG_POPULATE(PrimitiveType_SparseToDense, PopulateSparseToDenseParameter, SCHEMA_CUR); REG_POPULATE(PrimitiveType_SparseToDense, PopulateSparseToDenseParameter, SCHEMA_CUR);
} // namespace lite } // namespace lite


+ 36
- 32
mindspore/lite/src/ops/populate/splice_populate.cc View File

@@ -17,71 +17,75 @@
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "nnacl/splice_parameter.h" #include "nnacl/splice_parameter.h"
using mindspore::schema::PrimitiveType_Splice; using mindspore::schema::PrimitiveType_Splice;

namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateSpliceParameter(const void *prim) { OpParameter *PopulateSpliceParameter(const void *prim) {
auto *splice_parameter = reinterpret_cast<SpliceParameter *>(malloc(sizeof(SpliceParameter)));
if (splice_parameter == nullptr) {
MS_LOG(ERROR) << "malloc Splice Parameter failed.";
return nullptr;
}
memset(splice_parameter, 0, sizeof(SpliceParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto splice_primitive = primitive->value_as_Splice();
if (splice_primitive == nullptr) {
MS_LOG(ERROR) << "splice_primitive is nullptr";
auto value = primitive->value_as_Splice();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
splice_parameter->op_parameter_.type_ = primitive->value_type();


auto context = splice_primitive->context();
auto *param = reinterpret_cast<SpliceParameter *>(malloc(sizeof(SpliceParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc Splice Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(SpliceParameter));

param->op_parameter_.type_ = primitive->value_type();
auto context = value->context();
if (context == nullptr) { if (context == nullptr) {
MS_LOG(ERROR) << "context is nullptr"; MS_LOG(ERROR) << "context is nullptr";
free(param);
return nullptr; return nullptr;
} }
std::vector<int> primitive_context(context->begin(), context->end()); std::vector<int> primitive_context(context->begin(), context->end());
splice_parameter->context_dim_ = static_cast<int>(primitive_context.size());
param->context_dim_ = static_cast<int>(primitive_context.size());


// malloc && memset for context // malloc && memset for context
splice_parameter->context_ = reinterpret_cast<int *>(malloc(splice_parameter->context_dim_ * sizeof(int)));
if (splice_parameter->context_ == nullptr) {
MS_LOG(ERROR) << "malloc splice_parameter context_ error";
free(splice_parameter);
param->context_ = reinterpret_cast<int *>(malloc(param->context_dim_ * sizeof(int)));
if (param->context_ == nullptr) {
MS_LOG(ERROR) << "malloc param context_ error";
free(param);
return nullptr; return nullptr;
} }
// src_to_dst_row_offset // src_to_dst_row_offset
int src_to_dst_row_offset = INT32_MIN; int src_to_dst_row_offset = INT32_MIN;
memset(splice_parameter->context_, 0, splice_parameter->context_dim_ * sizeof(int));
for (int i = 0; i < splice_parameter->context_dim_; ++i) {
splice_parameter->context_[i] = primitive_context.at(i);
memset(param->context_, 0, param->context_dim_ * sizeof(int));
for (int i = 0; i < param->context_dim_; ++i) {
param->context_[i] = primitive_context.at(i);
src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i))); src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i)));
} }


auto forward_indexes = splice_primitive->forward_indexes();
auto forward_indexes = value->forward_indexes();
if (forward_indexes == nullptr) { if (forward_indexes == nullptr) {
MS_LOG(ERROR) << "forward_indexes is nullptr"; MS_LOG(ERROR) << "forward_indexes is nullptr";
free(param);
return nullptr; return nullptr;
} }
std::vector<int> primitive_forward_indexes(forward_indexes->begin(), forward_indexes->end()); std::vector<int> primitive_forward_indexes(forward_indexes->begin(), forward_indexes->end());
splice_parameter->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size());
param->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size());


// malloc && memset for forward_indexes // malloc && memset for forward_indexes
splice_parameter->forward_indexes_ =
reinterpret_cast<int *>(malloc(splice_parameter->forward_indexes_dim_ * sizeof(int)));
if (splice_parameter->forward_indexes_ == nullptr) {
MS_LOG(ERROR) << "malloc splice_parameter forward_indexes_ error";
free(splice_parameter->context_);
free(splice_parameter);
param->forward_indexes_ = reinterpret_cast<int *>(malloc(param->forward_indexes_dim_ * sizeof(int)));
if (param->forward_indexes_ == nullptr) {
MS_LOG(ERROR) << "malloc param forward_indexes_ error";
free(param->context_);
free(param);
return nullptr; return nullptr;
} }
memset(splice_parameter->forward_indexes_, 0, splice_parameter->forward_indexes_dim_ * sizeof(int));
for (int i = 0; i < splice_parameter->context_dim_; ++i) {
splice_parameter->context_[i] = primitive_context.at(i);
memset(param->forward_indexes_, 0, param->forward_indexes_dim_ * sizeof(int));
for (int i = 0; i < param->context_dim_; ++i) {
param->context_[i] = primitive_context.at(i);
} }
splice_parameter->output_dim_ = splice_primitive->output_dim();
return reinterpret_cast<OpParameter *>(splice_parameter);
param->output_dim_ = value->output_dim();
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_Splice, PopulateSpliceParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Splice, PopulateSpliceParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 24
- 24
mindspore/lite/src/ops/populate/split_populate.cc View File

@@ -19,15 +19,7 @@ using mindspore::schema::PrimitiveType_Split;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateSplitParameter(const void *prim) { OpParameter *PopulateSplitParameter(const void *prim) {
auto *split_param = reinterpret_cast<SplitParameter *>(malloc(sizeof(SplitParameter)));
if (split_param == nullptr) {
MS_LOG(ERROR) << "malloc SplitParameter failed.";
return nullptr;
}
memset(split_param, 0, sizeof(SplitParameter));

auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_Split(); auto value = primitive->value_as_Split();
@@ -35,36 +27,44 @@ OpParameter *PopulateSplitParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
split_param->op_parameter_.type_ = primitive->value_type();
split_param->num_split_ = value->output_num();
if (split_param->num_split_ > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
MS_LOG(ERROR) << "The value of split_param->num_split_ is too big";
free(split_param);

auto *param = reinterpret_cast<SplitParameter *>(malloc(sizeof(SplitParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SplitParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(SplitParameter));

param->op_parameter_.type_ = primitive->value_type();
param->num_split_ = value->output_num();
if (param->num_split_ > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
MS_LOG(ERROR) << "The value of param->num_split_ is too big";
free(param);
return nullptr; return nullptr;
} }


/* free split_sizes_ in split op base */ /* free split_sizes_ in split op base */
split_param->split_sizes_ = reinterpret_cast<int *>(malloc(split_param->num_split_ * sizeof(int)));
if (split_param->split_sizes_ == nullptr) {
MS_LOG(ERROR) << "malloc split_param split_sizes_ error";
free(split_param);
param->split_sizes_ = reinterpret_cast<int *>(malloc(param->num_split_ * sizeof(int)));
if (param->split_sizes_ == nullptr) {
MS_LOG(ERROR) << "malloc param split_sizes_ error";
free(param);
return nullptr; return nullptr;
} }
memset(split_param->split_sizes_, 0, split_param->num_split_ * sizeof(int));
memset(param->split_sizes_, 0, param->num_split_ * sizeof(int));
auto split_sizes_vector_ = value->size_splits(); auto split_sizes_vector_ = value->size_splits();
if (split_sizes_vector_ != nullptr) { if (split_sizes_vector_ != nullptr) {
int i = 0; int i = 0;
for (auto iter : *split_sizes_vector_) { for (auto iter : *split_sizes_vector_) {
split_param->split_sizes_[i++] = iter;
param->split_sizes_[i++] = iter;
} }
split_param->split_count_ = split_param->num_split_;
param->split_count_ = param->num_split_;
} else { } else {
split_param->split_count_ = 0;
param->split_count_ = 0;
} }
split_param->split_dim_ = value->axis();
return reinterpret_cast<OpParameter *>(split_param);
param->split_dim_ = value->axis();
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_Split, PopulateSplitParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Split, PopulateSplitParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 29
- 20
mindspore/lite/src/ops/populate/split_with_overlap_populate.cc View File

@@ -20,48 +20,57 @@ using mindspore::schema::PrimitiveType_SplitWithOverlap;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpParameter *PopulateSplitWithOverlapParameter(const void *prim) { OpParameter *PopulateSplitWithOverlapParameter(const void *prim) {
auto *split_with_over_lap_param =
reinterpret_cast<SplitWithOverlapParameter *>(malloc(sizeof(SplitWithOverlapParameter)));
if (split_with_over_lap_param == nullptr) {
MS_LOG(ERROR) << "malloc PopulateSplitWithOverlapParameter failed.";
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_SplitWithOverlap();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
memset(split_with_over_lap_param, 0, sizeof(SplitWithOverlapParameter));


auto primitive = static_cast<const schema::Primitive *>(prim);
auto value = primitive->value_as_SplitWithOverlap();
split_with_over_lap_param->op_parameter_.type_ = primitive->value_type();
auto *param = reinterpret_cast<SplitWithOverlapParameter *>(malloc(sizeof(SplitWithOverlapParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc PopulateSplitWithOverlapParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(SplitWithOverlapParameter));


param->op_parameter_.type_ = primitive->value_type();
auto ratio = value->ratio(); auto ratio = value->ratio();
if (ratio == nullptr) {
MS_LOG(ERROR) << "ratio is nullptr";
free(param);
return nullptr;
}
if (ratio->size() > SPLIT_MAX_SLICE_NUM) { if (ratio->size() > SPLIT_MAX_SLICE_NUM) {
MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM
<< " slices"; << " slices";
delete split_with_over_lap_param;
free(param);
return nullptr; return nullptr;
} }

split_with_over_lap_param->num_split_ = static_cast<int>(ratio->size());
split_with_over_lap_param->split_dim_ = value->split_dim();
param->num_split_ = static_cast<int>(ratio->size());
param->split_dim_ = value->split_dim();


auto extend_top = value->extend_top(); auto extend_top = value->extend_top();
auto extend_bottom = value->extend_bottom(); auto extend_bottom = value->extend_bottom();
if (extend_top->size() != ratio->size() || extend_bottom->size() != ratio->size()) {
if (extend_top->size() != ratio->size() || (extend_bottom != nullptr && extend_bottom->size() != ratio->size())) {
MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical"; MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical";
delete split_with_over_lap_param;
free(param);
return nullptr; return nullptr;
} }


for (size_t i = 0; i < ratio->size(); ++i) { for (size_t i = 0; i < ratio->size(); ++i) {
split_with_over_lap_param->ratio_[i] = (*ratio)[i];
split_with_over_lap_param->extend_top_[i] = (*extend_top)[i];
split_with_over_lap_param->extend_bottom_[i] = (*extend_bottom)[i];
param->ratio_[i] = (*ratio)[i];
param->extend_top_[i] = (*extend_top)[i];
param->extend_bottom_[i] = (*extend_bottom)[i];
} }


split_with_over_lap_param->stride_ = value->stride();
split_with_over_lap_param->pad_top_ = value->pad_top();
param->stride_ = value->stride();
param->pad_top_ = value->pad_top();


return reinterpret_cast<OpParameter *>(split_with_over_lap_param);
return reinterpret_cast<OpParameter *>(param);
} }

REG_POPULATE(PrimitiveType_SplitWithOverlap, PopulateSplitWithOverlapParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_SplitWithOverlap, PopulateSplitWithOverlapParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 19
- 20
mindspore/lite/src/ops/populate/squeeze_populate.cc View File

@@ -19,36 +19,35 @@ using mindspore::schema::PrimitiveType_Squeeze;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateSqueezeParameter(const void *prim) { OpParameter *PopulateSqueezeParameter(const void *prim) {
SqueezeParameter *squeeze_param = reinterpret_cast<SqueezeParameter *>(malloc(sizeof(SqueezeParameter)));
if (squeeze_param == nullptr) {
MS_LOG(ERROR) << "malloc SqueezeParameter failed.";
return nullptr;
}
memset(squeeze_param, 0, sizeof(SqueezeParameter));
auto *primitive = static_cast<const schema::Primitive *>(prim); auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
squeeze_param->op_parameter_.type_ = primitive->value_type();
auto value = primitive->value_as_Squeeze();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}


auto squeeze_prim = primitive->value_as_Squeeze();
if (squeeze_prim == nullptr) {
MS_LOG(ERROR) << "squeeze_prim is nullptr";
auto *param = reinterpret_cast<SqueezeParameter *>(malloc(sizeof(SqueezeParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc SqueezeParameter failed.";
return nullptr; return nullptr;
} }
auto axis = squeeze_prim->axis();
if (squeeze_prim->axis() != nullptr) {
squeeze_param->axis_size_ = axis->size();
for (size_t i = 0; i < squeeze_param->axis_size_; i++) {
squeeze_param->axis_[i] = *(axis->begin() + i);
memset(param, 0, sizeof(SqueezeParameter));

param->op_parameter_.type_ = primitive->value_type();
auto axis = value->axis();
if (axis != nullptr) {
param->axis_size_ = axis->size();
for (size_t i = 0; i < param->axis_size_; i++) {
param->axis_[i] = *(axis->begin() + i);
} }
} else { } else {
squeeze_param->axis_size_ = 0;
param->axis_size_ = 0;
} }

return reinterpret_cast<OpParameter *>(squeeze_param);
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_Squeeze, PopulateSqueezeParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Squeeze, PopulateSqueezeParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 12
- 11
mindspore/lite/src/ops/populate/stack_populate.cc View File

@@ -19,14 +19,7 @@ using mindspore::schema::PrimitiveType_Stack;


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
OpParameter *PopulateStackParameter(const void *prim) { OpParameter *PopulateStackParameter(const void *prim) {
auto *stack_param = reinterpret_cast<StackParameter *>(malloc(sizeof(StackParameter)));
if (stack_param == nullptr) {
MS_LOG(ERROR) << "malloc StackParameter failed.";
return nullptr;
}
memset(stack_param, 0, sizeof(StackParameter));
auto primitive = static_cast<const schema::Primitive *>(prim); auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_Stack(); auto value = primitive->value_as_Stack();
@@ -34,11 +27,19 @@ OpParameter *PopulateStackParameter(const void *prim) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return nullptr; return nullptr;
} }
stack_param->op_parameter_.type_ = primitive->value_type();
stack_param->axis_ = static_cast<int>(value->axis());
return reinterpret_cast<OpParameter *>(stack_param);

auto *param = reinterpret_cast<StackParameter *>(malloc(sizeof(StackParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc StackParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(StackParameter));

param->op_parameter_.type_ = primitive->value_type();
param->axis_ = static_cast<int>(value->axis());
return reinterpret_cast<OpParameter *>(param);
} }
} // namespace
REG_POPULATE(PrimitiveType_Stack, PopulateStackParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Stack, PopulateStackParameter, SCHEMA_CUR)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save