| @@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_ActivationGrad; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateActivationGradParameter(const void *prim) { | OpParameter *PopulateActivationGradParameter(const void *prim) { | ||||
| auto *act_param = reinterpret_cast<ActivationGradParameter *>(malloc(sizeof(ActivationGradParameter))); | |||||
| if (act_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ActivationParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(act_param, 0, sizeof(ActivationGradParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_ActivationGrad(); | auto value = primitive->value_as_ActivationGrad(); | ||||
| @@ -34,11 +27,20 @@ OpParameter *PopulateActivationGradParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| act_param->op_parameter.type_ = primitive->value_type(); | |||||
| act_param->type_ = static_cast<int>(value->activation_type()); | |||||
| act_param->alpha_ = value->alpha(); | |||||
| return reinterpret_cast<OpParameter *>(act_param); | |||||
| auto *param = reinterpret_cast<ActivationGradParameter *>(malloc(sizeof(ActivationGradParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ActivationParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(ActivationGradParameter)); | |||||
| param->op_parameter.type_ = primitive->value_type(); | |||||
| param->type_ = static_cast<int>(value->activation_type()); | |||||
| param->alpha_ = value->alpha(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_ActivationGrad, PopulateActivationGradParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_ActivationGrad, PopulateActivationGradParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,29 +19,29 @@ using mindspore::schema::PrimitiveType_Activation; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateRelu6Parameter(const void *prim) { | OpParameter *PopulateRelu6Parameter(const void *prim) { | ||||
| auto *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter))); | |||||
| if (act_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ActivationParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(act_param, 0, sizeof(ActivationParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| act_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto acti_prim = primitive->value_as_Activation(); | |||||
| if (acti_prim == nullptr) { | |||||
| MS_LOG(ERROR) << "acti_prim is nullptr"; | |||||
| auto value = primitive->value_as_Activation(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ActivationParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| act_param->type_ = static_cast<int>(acti_prim->activation_type()); | |||||
| act_param->alpha_ = acti_prim->alpha(); | |||||
| act_param->min_val_ = acti_prim->min_val(); | |||||
| act_param->max_val_ = acti_prim->max_val(); | |||||
| return reinterpret_cast<OpParameter *>(act_param); | |||||
| memset(param, 0, sizeof(ActivationParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->type_ = static_cast<int>(value->activation_type()); | |||||
| param->alpha_ = value->alpha(); | |||||
| param->min_val_ = value->min_val(); | |||||
| param->max_val_ = value->max_val(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Activation, PopulateRelu6Parameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Activation, PopulateRelu6Parameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,14 +19,16 @@ using mindspore::schema::PrimitiveType_Adam; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateAdamParameter(const void *prim) { | OpParameter *PopulateAdamParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc Adam Parameter failed."; | MS_LOG(ERROR) << "malloc Adam Parameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(OpParameter)); | memset(param, 0, sizeof(OpParameter)); | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->type_ = primitive->value_type(); | param->type_ = primitive->value_type(); | ||||
| return param; | return param; | ||||
| } | } | ||||
| @@ -20,24 +20,25 @@ using mindspore::schema::PrimitiveType_AddFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateAddParameter(const void *prim) { | OpParameter *PopulateAddParameter(const void *prim) { | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| auto value = primitive->value_as_AddFusion(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| ArithmeticParameter *param = PopulateArithmeticCommonPara(prim); | ArithmeticParameter *param = PopulateArithmeticCommonPara(prim); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | param->op_parameter_.type_ = primitive->value_type(); | ||||
| auto add_prim = primitive->value_as_AddFusion(); | |||||
| if (add_prim == nullptr) { | |||||
| MS_LOG(ERROR) << "add_prim is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| param->activation_type_ = add_prim->activation_type(); | |||||
| param->activation_type_ = value->activation_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_AddFusion, PopulateAddParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_AddFusion, PopulateAddParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,56 +21,59 @@ using mindspore::schema::PrimitiveType_AdderFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateAdderParameter(const void *prim) { | OpParameter *PopulateAdderParameter(const void *prim) { | ||||
| ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto value = primitive->value_as_AdderFusion(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(conv_param, 0, sizeof(ConvParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| conv_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto conv_primitive = primitive->value_as_AdderFusion(); | |||||
| if (conv_primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "conv_primitive is nullptr"; | |||||
| auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto kernel_size = conv_primitive->kernel_size(); | |||||
| auto stride = conv_primitive->stride(); | |||||
| auto pad_list = conv_primitive->pad_list(); | |||||
| auto dilation = conv_primitive->dilation(); | |||||
| memset(param, 0, sizeof(ConvParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto kernel_size = value->kernel_size(); | |||||
| auto stride = value->stride(); | |||||
| auto pad_list = value->pad_list(); | |||||
| auto dilation = value->dilation(); | |||||
| if (kernel_size == nullptr || stride == nullptr || pad_list == nullptr || dilation == nullptr) { | if (kernel_size == nullptr || stride == nullptr || pad_list == nullptr || dilation == nullptr) { | ||||
| MS_LOG(ERROR) << "nullptr"; | MS_LOG(ERROR) << "nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| conv_param->kernel_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| conv_param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| conv_param->group_ = static_cast<int>(conv_primitive->group()); | |||||
| conv_param->stride_h_ = static_cast<int>(*(stride->begin())); | |||||
| conv_param->stride_w_ = static_cast<int>(*(stride->begin() + 1)); | |||||
| conv_param->pad_u_ = static_cast<int>(*(pad_list->begin())); | |||||
| conv_param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1)); | |||||
| conv_param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2)); | |||||
| conv_param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3)); | |||||
| conv_param->dilation_h_ = static_cast<int>(*(dilation->begin())); | |||||
| conv_param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1)); | |||||
| conv_param->input_channel_ = static_cast<int>(conv_primitive->in_channel()); | |||||
| conv_param->output_channel_ = static_cast<int>(conv_primitive->out_channel()); | |||||
| auto act_type = conv_primitive->activation_type(); | |||||
| param->kernel_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| param->group_ = static_cast<int>(value->group()); | |||||
| param->stride_h_ = static_cast<int>(*(stride->begin())); | |||||
| param->stride_w_ = static_cast<int>(*(stride->begin() + 1)); | |||||
| param->pad_u_ = static_cast<int>(*(pad_list->begin())); | |||||
| param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1)); | |||||
| param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2)); | |||||
| param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3)); | |||||
| param->dilation_h_ = static_cast<int>(*(dilation->begin())); | |||||
| param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1)); | |||||
| param->input_channel_ = static_cast<int>(value->in_channel()); | |||||
| param->output_channel_ = static_cast<int>(value->out_channel()); | |||||
| auto act_type = value->activation_type(); | |||||
| switch (act_type) { | switch (act_type) { | ||||
| case schema::ActivationType_RELU: | case schema::ActivationType_RELU: | ||||
| conv_param->act_type_ = ActType_Relu; | |||||
| param->act_type_ = ActType_Relu; | |||||
| break; | break; | ||||
| case schema::ActivationType_RELU6: | case schema::ActivationType_RELU6: | ||||
| conv_param->act_type_ = ActType_Relu6; | |||||
| param->act_type_ = ActType_Relu6; | |||||
| break; | break; | ||||
| default: | default: | ||||
| conv_param->act_type_ = ActType_No; | |||||
| param->act_type_ = ActType_No; | |||||
| break; | break; | ||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(conv_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_AdderFusion, PopulateAdderParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_AdderFusion, PopulateAdderParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,20 +19,21 @@ using mindspore::schema::PrimitiveType_AddN; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateAddNParameter(const void *prim) { | OpParameter *PopulateAddNParameter(const void *prim) { | ||||
| auto *addn_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (addn_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OpParameter failed."; | MS_LOG(ERROR) << "malloc OpParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(addn_param, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| addn_param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(addn_param); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_AddN, PopulateAddNParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_AddN, PopulateAddNParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,7 +19,6 @@ using mindspore::schema::PrimitiveType_ArgMaxFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateArgMaxParameter(const void *prim) { | OpParameter *PopulateArgMaxParameter(const void *prim) { | ||||
| auto *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter))); | auto *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter))); | ||||
| if (arg_param == nullptr) { | if (arg_param == nullptr) { | ||||
| @@ -32,6 +31,7 @@ OpParameter *PopulateArgMaxParameter(const void *prim) { | |||||
| auto param = primitive->value_as_ArgMaxFusion(); | auto param = primitive->value_as_ArgMaxFusion(); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | MS_LOG(ERROR) << "param is nullptr"; | ||||
| free(arg_param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| arg_param->axis_ = param->axis(); | arg_param->axis_ = param->axis(); | ||||
| @@ -41,7 +41,6 @@ OpParameter *PopulateArgMaxParameter(const void *prim) { | |||||
| arg_param->get_max_ = true; | arg_param->get_max_ = true; | ||||
| return reinterpret_cast<OpParameter *>(arg_param); | return reinterpret_cast<OpParameter *>(arg_param); | ||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_ArgMaxFusion, PopulateArgMaxParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ArgMaxFusion, PopulateArgMaxParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,29 +19,29 @@ using mindspore::schema::PrimitiveType_ArgMinFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateArgMinParameter(const void *prim) { | OpParameter *PopulateArgMinParameter(const void *prim) { | ||||
| ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter))); | |||||
| if (arg_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed."; | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| auto value = primitive->value_as_ArgMinFusion(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(arg_param, 0, sizeof(ArgMinMaxParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| arg_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_ArgMinFusion(); | |||||
| auto *param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| arg_param->axis_ = param->axis(); | |||||
| arg_param->topk_ = param->top_k(); | |||||
| arg_param->out_value_ = param->out_max_value(); | |||||
| arg_param->keep_dims_ = param->keep_dims(); | |||||
| arg_param->get_max_ = false; | |||||
| return reinterpret_cast<OpParameter *>(arg_param); | |||||
| memset(param, 0, sizeof(ArgMinMaxParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->axis_ = value->axis(); | |||||
| param->topk_ = value->top_k(); | |||||
| param->out_value_ = value->out_max_value(); | |||||
| param->keep_dims_ = value->keep_dims(); | |||||
| param->get_max_ = false; | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_ArgMinFusion, PopulateArgMinParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ArgMinFusion, PopulateArgMinParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -36,14 +36,16 @@ using mindspore::schema::PrimitiveType_SquaredDifference; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| ArithmeticParameter *PopulateArithmeticCommonPara(const void *prim) { | ArithmeticParameter *PopulateArithmeticCommonPara(const void *prim) { | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(ArithmeticParameter)); | memset(param, 0, sizeof(ArithmeticParameter)); | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | param->op_parameter_.type_ = primitive->value_type(); | ||||
| param->broadcasting_ = false; | param->broadcasting_ = false; | ||||
| param->ndim_ = 0; | param->ndim_ = 0; | ||||
| @@ -35,16 +35,18 @@ using mindspore::schema::PrimitiveType_Square; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateArithmeticSelf(const void *prim) { | OpParameter *PopulateArithmeticSelf(const void *prim) { | ||||
| auto *arithmetic_self_param = reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter))); | |||||
| if (arithmetic_self_param == nullptr) { | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed."; | MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(arithmetic_self_param, 0, sizeof(ArithmeticSelfParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| arithmetic_self_param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_self_param); | |||||
| memset(param, 0, sizeof(ArithmeticSelfParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_Abs, PopulateArithmeticSelf, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Abs, PopulateArithmeticSelf, SCHEMA_CUR) | ||||
| @@ -18,20 +18,21 @@ using mindspore::schema::PrimitiveType_Assert; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateAssertParameter(const void *prim) { | OpParameter *PopulateAssertParameter(const void *prim) { | ||||
| auto *assert_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (assert_parameter == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc AssertParameter failed."; | MS_LOG(ERROR) << "malloc AssertParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(assert_parameter, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| assert_parameter->type_ = primitive->value_type(); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| return reinterpret_cast<OpParameter *>(assert_parameter); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_Assert, PopulateAssertParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Assert, PopulateAssertParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,14 +19,16 @@ using mindspore::schema::PrimitiveType_AssignAdd; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateAssignAddParameter(const void *prim) { | OpParameter *PopulateAssignAddParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc AssignAdd Parameter failed."; | MS_LOG(ERROR) << "malloc AssignAdd Parameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(OpParameter)); | memset(param, 0, sizeof(OpParameter)); | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->type_ = primitive->value_type(); | param->type_ = primitive->value_type(); | ||||
| return param; | return param; | ||||
| } | } | ||||
| @@ -19,6 +19,9 @@ using mindspore::schema::PrimitiveType_Assign; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateAssignParameter(const void *prim) { | OpParameter *PopulateAssignParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc Assign Parameter failed."; | MS_LOG(ERROR) << "malloc Assign Parameter failed."; | ||||
| @@ -26,8 +29,6 @@ OpParameter *PopulateAssignParameter(const void *prim) { | |||||
| } | } | ||||
| memset(param, 0, sizeof(OpParameter)); | memset(param, 0, sizeof(OpParameter)); | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->type_ = primitive->value_type(); | param->type_ = primitive->value_type(); | ||||
| return param; | return param; | ||||
| } | } | ||||
| @@ -21,25 +21,27 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| namespace { | namespace { | ||||
| OpParameter *PopulateAudioSpectrogramParameter(const void *prim) { | OpParameter *PopulateAudioSpectrogramParameter(const void *prim) { | ||||
| auto *arg_param = reinterpret_cast<AudioSpectrogramParameter *>(malloc(sizeof(AudioSpectrogramParameter))); | |||||
| if (arg_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc AudioSpectrogramParameter failed."; | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| auto value = primitive->value_as_AudioSpectrogram(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(arg_param, 0, sizeof(AudioSpectrogramParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| arg_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_AudioSpectrogram(); | |||||
| auto *param = reinterpret_cast<AudioSpectrogramParameter *>(malloc(sizeof(AudioSpectrogramParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc AudioSpectrogramParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| arg_param->window_size_ = param->window_size(); | |||||
| arg_param->stride_ = param->stride(); | |||||
| return reinterpret_cast<OpParameter *>(arg_param); | |||||
| memset(param, 0, sizeof(AudioSpectrogramParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->window_size_ = value->window_size(); | |||||
| param->stride_ = value->stride(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| REG_POPULATE(PrimitiveType_AudioSpectrogram, PopulateAudioSpectrogramParameter, SCHEMA_CUR); | |||||
| REG_POPULATE(PrimitiveType_AudioSpectrogram, PopulateAudioSpectrogramParameter, SCHEMA_CUR) | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,27 +19,27 @@ using mindspore::schema::PrimitiveType_BatchNorm; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateBatchNorm(const void *prim) { | OpParameter *PopulateBatchNorm(const void *prim) { | ||||
| auto *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter))); | |||||
| if (batch_norm_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BatchNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(batch_norm_param, 0, sizeof(BatchNormParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| batch_norm_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto prim_batchnorm = primitive->value_as_BatchNorm(); | |||||
| if (prim_batchnorm == nullptr) { | |||||
| MS_LOG(ERROR) << "prim_batchnorm is nullptr"; | |||||
| auto value = primitive->value_as_BatchNorm(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BatchNormParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| batch_norm_param->epsilon_ = prim_batchnorm->epsilon(); | |||||
| batch_norm_param->fused_ = false; | |||||
| return reinterpret_cast<OpParameter *>(batch_norm_param); | |||||
| memset(param, 0, sizeof(BatchNormParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->epsilon_ = value->epsilon(); | |||||
| param->fused_ = false; | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_BatchNorm, PopulateBatchNorm, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_BatchNorm, PopulateBatchNorm, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,48 +20,52 @@ using mindspore::schema::PrimitiveType_BatchToSpaceND; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateBatchToSpaceParameter(const void *prim) { | OpParameter *PopulateBatchToSpaceParameter(const void *prim) { | ||||
| auto *batch_space_param = reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter))); | |||||
| if (batch_space_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(batch_space_param, 0, sizeof(BatchToSpaceParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| batch_space_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_BatchToSpace(); | |||||
| auto value = primitive->value_as_BatchToSpace(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto block_size = param->block_size(); | |||||
| memset(param, 0, sizeof(BatchToSpaceParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto block_size = value->block_size(); | |||||
| if (block_size == nullptr) { | if (block_size == nullptr) { | ||||
| return reinterpret_cast<OpParameter *>(batch_space_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| auto block_shape = std::vector<int64_t>(block_size->begin(), block_size->end()); | auto block_shape = std::vector<int64_t>(block_size->begin(), block_size->end()); | ||||
| if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { | if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { | ||||
| MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; | MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; | ||||
| free(batch_space_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto crop = param->crops(); | |||||
| auto crop = value->crops(); | |||||
| if (crop == nullptr) { | if (crop == nullptr) { | ||||
| MS_LOG(ERROR) << "crop is nullptr"; | MS_LOG(ERROR) << "crop is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto fb_crops = crop->data(); | auto fb_crops = crop->data(); | ||||
| if (fb_crops == nullptr) { | if (fb_crops == nullptr) { | ||||
| MS_LOG(ERROR) << "fb_crops is nullptr"; | MS_LOG(ERROR) << "fb_crops is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::vector<int64_t> crops; | std::vector<int64_t> crops; | ||||
| for (auto iter = fb_crops->begin(); iter != fb_crops->end(); ++iter) { | |||||
| auto crops_data = (*iter)->data(); | |||||
| for (auto fb_crop : *fb_crops) { | |||||
| auto crops_data = fb_crop->data(); | |||||
| if (crops_data == nullptr) { | if (crops_data == nullptr) { | ||||
| MS_LOG(ERROR) << "crops_data is nullptr"; | MS_LOG(ERROR) << "crops_data is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto crops_vec = std::vector<int64_t>(crops_data->begin(), crops_data->end()); | auto crops_vec = std::vector<int64_t>(crops_data->begin(), crops_data->end()); | ||||
| @@ -69,20 +73,20 @@ OpParameter *PopulateBatchToSpaceParameter(const void *prim) { | |||||
| } | } | ||||
| if (crops.size() != COMM_SHAPE_SIZE) { | if (crops.size() != COMM_SHAPE_SIZE) { | ||||
| MS_LOG(ERROR) << "batch_to_space crops size should be " << COMM_SHAPE_SIZE; | MS_LOG(ERROR) << "batch_to_space crops size should be " << COMM_SHAPE_SIZE; | ||||
| free(batch_space_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { | for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { | ||||
| batch_space_param->block_shape_[i] = static_cast<int>(block_shape[i]); | |||||
| param->block_shape_[i] = static_cast<int>(block_shape[i]); | |||||
| } | } | ||||
| for (int i = 0; i < COMM_SHAPE_SIZE; ++i) { | for (int i = 0; i < COMM_SHAPE_SIZE; ++i) { | ||||
| batch_space_param->crops_[i] = static_cast<int>(crops[i]); | |||||
| param->crops_[i] = static_cast<int>(crops[i]); | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(batch_space_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter, SCHEMA_CUR) | ||||
| REG_POPULATE(PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,21 +19,21 @@ using mindspore::schema::PrimitiveType_BiasAdd; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateBiasAddParameter(const void *prim) { | OpParameter *PopulateBiasAddParameter(const void *prim) { | ||||
| auto *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (arithmetic_param == nullptr) { | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| arithmetic_param->op_parameter_.type_ = primitive->value_type(); | |||||
| memset(param, 0, sizeof(ArithmeticParameter)); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_BiasAdd, PopulateBiasAddParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_BiasAdd, PopulateBiasAddParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,21 +19,20 @@ using mindspore::schema::PrimitiveType_BiasAddGrad; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateBiasAddGradParameter(const void *prim) { | OpParameter *PopulateBiasAddGradParameter(const void *prim) { | ||||
| auto *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (arithmetic_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||||
| memset(param, 0, sizeof(ArithmeticParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| arithmetic_param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_BiasAddGrad, PopulateBiasAddGradParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_BiasAddGrad, PopulateBiasAddGradParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,27 +19,26 @@ using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) { | OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) { | ||||
| auto *bce_param = | |||||
| reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter))); | |||||
| if (bce_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(bce_param, 0, sizeof(BinaryCrossEntropyGradParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| bce_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_BinaryCrossEntropyGrad(); | |||||
| if (param == nullptr) { | |||||
| auto value = primitive->value_as_BinaryCrossEntropyGrad(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "param is nullptr"; | MS_LOG(ERROR) << "param is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| bce_param->reduction = param->reduction(); | |||||
| return reinterpret_cast<OpParameter *>(bce_param); | |||||
| auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(BinaryCrossEntropyGradParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->reduction = value->reduction(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_BinaryCrossEntropy; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) { | OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) { | ||||
| BinaryCrossEntropyParameter *bce_param = | |||||
| reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter))); | |||||
| if (bce_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(bce_param, 0, sizeof(BinaryCrossEntropyParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_BinaryCrossEntropy(); | auto value = primitive->value_as_BinaryCrossEntropy(); | ||||
| @@ -34,9 +27,17 @@ OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| bce_param->op_parameter_.type_ = primitive->value_type(); | |||||
| bce_param->reduction = value->reduction(); | |||||
| return reinterpret_cast<OpParameter *>(bce_param); | |||||
| auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(BinaryCrossEntropyParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->reduction = value->reduction(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR); | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_BroadcastTo; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateBroadcastToParameter(const void *prim) { | OpParameter *PopulateBroadcastToParameter(const void *prim) { | ||||
| auto *broadcast_param = reinterpret_cast<BroadcastToParameter *>(malloc(sizeof(BroadcastToParameter))); | |||||
| if (broadcast_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BroadcastToParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(broadcast_param, 0, sizeof(BroadcastToParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_BroadcastTo(); | auto value = primitive->value_as_BroadcastTo(); | ||||
| @@ -33,17 +27,26 @@ OpParameter *PopulateBroadcastToParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| broadcast_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto *param = reinterpret_cast<BroadcastToParameter *>(malloc(sizeof(BroadcastToParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BroadcastToParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(BroadcastToParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto dst_shape = value->shape(); | auto dst_shape = value->shape(); | ||||
| if (dst_shape == nullptr) { | if (dst_shape == nullptr) { | ||||
| MS_LOG(ERROR) << "dst_shape is nullptr"; | MS_LOG(ERROR) << "dst_shape is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| broadcast_param->shape_size_ = dst_shape->size(); | |||||
| for (size_t i = 0; i < broadcast_param->shape_size_; ++i) { | |||||
| broadcast_param->shape_[i] = dst_shape->Get(i); | |||||
| param->shape_size_ = dst_shape->size(); | |||||
| for (size_t i = 0; i < param->shape_size_; ++i) { | |||||
| param->shape_[i] = dst_shape->Get(i); | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(broadcast_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_BroadcastTo, PopulateBroadcastToParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_BroadcastTo, PopulateBroadcastToParameter, SCHEMA_CUR) | ||||
| @@ -19,16 +19,20 @@ using mindspore::schema::PrimitiveType_Call; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateCallParameter(const void *prim) { | OpParameter *PopulateCallParameter(const void *prim) { | ||||
| OpParameter *call_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (call_parameter == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc CallParameter failed."; | MS_LOG(ERROR) << "malloc CallParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(call_parameter, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| call_parameter->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(call_parameter); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_Call, PopulateCallParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Call, PopulateCallParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_Cast; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateCastParameter(const void *prim) { | OpParameter *PopulateCastParameter(const void *prim) { | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *cast_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | auto *cast_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | ||||
| if (cast_param == nullptr) { | if (cast_param == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc CastParameter failed."; | MS_LOG(ERROR) << "malloc CastParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(cast_param, 0, sizeof(OpParameter)); | memset(cast_param, 0, sizeof(OpParameter)); | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| cast_param->type_ = primitive->value_type(); | cast_param->type_ = primitive->value_type(); | ||||
| return reinterpret_cast<OpParameter *>(cast_param); | return reinterpret_cast<OpParameter *>(cast_param); | ||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Cast, PopulateCastParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Cast, PopulateCastParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_Clip; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateClipParameter(const void *prim) { | OpParameter *PopulateClipParameter(const void *prim) { | ||||
| auto *act_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (act_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ClipParameter failed."; | MS_LOG(ERROR) << "malloc ClipParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(act_param, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| act_param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(act_param); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Clip, PopulateClipParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Clip, PopulateClipParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,20 +19,20 @@ using mindspore::schema::PrimitiveType_ZerosLike; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateCommonParameter(const void *prim) { | OpParameter *PopulateCommonParameter(const void *prim) { | ||||
| auto *common_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (common_parameter == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OpParameter failed."; | MS_LOG(ERROR) << "malloc OpParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(common_parameter, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| common_parameter->type_ = primitive->value_type(); | |||||
| return common_parameter; | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return param; | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR) | ||||
| REG_POPULATE(PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR) | ||||
| @@ -19,26 +19,26 @@ using mindspore::schema::PrimitiveType_Concat; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateConcatParameter(const void *prim) { | OpParameter *PopulateConcatParameter(const void *prim) { | ||||
| auto *concat_param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter))); | |||||
| if (concat_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConcatParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(concat_param, 0, sizeof(ConcatParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| concat_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_Concat(); | |||||
| if (param == nullptr) { | |||||
| auto value = primitive->value_as_Concat(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "param is nullptr"; | MS_LOG(ERROR) << "param is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| concat_param->axis_ = static_cast<int>(param->axis()); | |||||
| return reinterpret_cast<OpParameter *>(concat_param); | |||||
| auto *param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConcatParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(ConcatParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->axis_ = static_cast<int>(value->axis()); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Concat, PopulateConcatParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Concat, PopulateConcatParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -17,39 +17,42 @@ | |||||
| #include "nnacl/constant_of_shape_parameter.h" | #include "nnacl/constant_of_shape_parameter.h" | ||||
| using mindspore::schema::PrimitiveType_ConstantOfShape; | using mindspore::schema::PrimitiveType_ConstantOfShape; | ||||
| namespace mindspore::lite { | |||||
| namespace { | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| OpParameter *PopulateConstantOfShapeParameter(const void *prim) { | OpParameter *PopulateConstantOfShapeParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto value = primitive->value_as_ConstantOfShape(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<ConstantOfShapeParameter *>(malloc(sizeof(ConstantOfShapeParameter))); | auto *param = reinterpret_cast<ConstantOfShapeParameter *>(malloc(sizeof(ConstantOfShapeParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc ConstantOfShapeParameter failed."; | MS_LOG(ERROR) << "malloc ConstantOfShapeParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(ConstantOfShapeParameter)); | memset(param, 0, sizeof(ConstantOfShapeParameter)); | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | param->op_parameter_.type_ = primitive->value_type(); | ||||
| auto attr = primitive->value_as_ConstantOfShape(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "attr is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto val = attr->value(); | |||||
| if (val == nullptr) { | |||||
| auto prim_val = value->value(); | |||||
| if (prim_val == nullptr) { | |||||
| MS_LOG(ERROR) << "val is nullptr"; | MS_LOG(ERROR) << "val is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto value = std::vector<float>(val->begin(), val->end()); | |||||
| param->data_type_ = static_cast<int>(attr->data_type()); | |||||
| if (value.empty() || value.size() > 1) { | |||||
| auto val = std::vector<float>(prim_val->begin(), prim_val->end()); | |||||
| param->data_type_ = static_cast<int>(value->data_type()); | |||||
| if (val.empty() || val.size() > 1) { | |||||
| MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1."; | MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1."; | ||||
| } else { | } else { | ||||
| switch (param->data_type_) { | switch (param->data_type_) { | ||||
| case kNumberTypeFloat32: | case kNumberTypeFloat32: | ||||
| param->value_.f32_value_ = *(val->begin()); | |||||
| param->value_.f32_value_ = *(prim_val->begin()); | |||||
| break; | break; | ||||
| case kNumberTypeInt32: | case kNumberTypeInt32: | ||||
| param->value_.int32_value_ = *(val->begin()); | |||||
| param->value_.int32_value_ = *(prim_val->begin()); | |||||
| break; | break; | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "The value of constant of shape is invalid"; | MS_LOG(ERROR) << "The value of constant of shape is invalid"; | ||||
| @@ -57,6 +60,7 @@ OpParameter *PopulateConstantOfShapeParameter(const void *prim) { | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter, SCHEMA_CUR); | ||||
| } // namespace mindspore::lite | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -20,74 +20,76 @@ using mindspore::schema::PrimitiveType_Conv2DFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateConvParameter(const void *prim) { | OpParameter *PopulateConvParameter(const void *prim) { | ||||
| auto *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(conv_param, 0, sizeof(ConvParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| conv_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto conv_primitive = primitive->value_as_Conv2DFusion(); | |||||
| if (conv_primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "conv_primitive is nullptr"; | |||||
| auto value = primitive->value_as_Conv2DFusion(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto kernel_size = conv_primitive->kernel_size(); | |||||
| auto stride = conv_primitive->stride(); | |||||
| auto pad_list = conv_primitive->pad_list(); | |||||
| auto dilation = conv_primitive->dilation(); | |||||
| auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(ConvParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto kernel_size = value->kernel_size(); | |||||
| auto stride = value->stride(); | |||||
| auto pad_list = value->pad_list(); | |||||
| auto dilation = value->dilation(); | |||||
| if (kernel_size == nullptr || stride == nullptr || dilation == nullptr) { | if (kernel_size == nullptr || stride == nullptr || dilation == nullptr) { | ||||
| MS_LOG(ERROR) << "nullptr"; | MS_LOG(ERROR) << "nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| conv_param->kernel_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| conv_param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| conv_param->group_ = static_cast<int>(conv_primitive->group()); | |||||
| conv_param->stride_h_ = static_cast<int>(*(stride->begin())); | |||||
| conv_param->stride_w_ = static_cast<int>(*(stride->begin() + 1)); | |||||
| switch (conv_primitive->pad_mode()) { | |||||
| param->kernel_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| param->group_ = static_cast<int>(value->group()); | |||||
| param->stride_h_ = static_cast<int>(*(stride->begin())); | |||||
| param->stride_w_ = static_cast<int>(*(stride->begin() + 1)); | |||||
| switch (value->pad_mode()) { | |||||
| case schema::PadMode_SAME: | case schema::PadMode_SAME: | ||||
| conv_param->pad_mode_ = Pad_same; | |||||
| param->pad_mode_ = Pad_same; | |||||
| break; | break; | ||||
| case schema::PadMode_VALID: | case schema::PadMode_VALID: | ||||
| conv_param->pad_mode_ = Pad_valid; | |||||
| param->pad_mode_ = Pad_valid; | |||||
| break; | break; | ||||
| default: | default: | ||||
| conv_param->pad_mode_ = Pad_pad; | |||||
| param->pad_mode_ = Pad_pad; | |||||
| } | } | ||||
| if (pad_list == nullptr || pad_list->size() < 4) { | if (pad_list == nullptr || pad_list->size() < 4) { | ||||
| conv_param->pad_u_ = 0; | |||||
| conv_param->pad_d_ = 0; | |||||
| conv_param->pad_l_ = 0; | |||||
| conv_param->pad_r_ = 0; | |||||
| param->pad_u_ = 0; | |||||
| param->pad_d_ = 0; | |||||
| param->pad_l_ = 0; | |||||
| param->pad_r_ = 0; | |||||
| } else { | } else { | ||||
| conv_param->pad_u_ = static_cast<int>(*(pad_list->begin())); | |||||
| conv_param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1)); | |||||
| conv_param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2)); | |||||
| conv_param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3)); | |||||
| param->pad_u_ = static_cast<int>(*(pad_list->begin())); | |||||
| param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1)); | |||||
| param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2)); | |||||
| param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3)); | |||||
| } | } | ||||
| conv_param->dilation_h_ = static_cast<int>(*(dilation->begin())); | |||||
| conv_param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1)); | |||||
| conv_param->input_channel_ = static_cast<int>(conv_primitive->in_channel()); | |||||
| conv_param->output_channel_ = static_cast<int>(conv_primitive->out_channel()); | |||||
| auto act_type = conv_primitive->activation_type(); | |||||
| param->dilation_h_ = static_cast<int>(*(dilation->begin())); | |||||
| param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1)); | |||||
| param->input_channel_ = static_cast<int>(value->in_channel()); | |||||
| param->output_channel_ = static_cast<int>(value->out_channel()); | |||||
| auto act_type = value->activation_type(); | |||||
| switch (act_type) { | switch (act_type) { | ||||
| case schema::ActivationType_RELU: | case schema::ActivationType_RELU: | ||||
| conv_param->act_type_ = ActType_Relu; | |||||
| param->act_type_ = ActType_Relu; | |||||
| break; | break; | ||||
| case schema::ActivationType_RELU6: | case schema::ActivationType_RELU6: | ||||
| conv_param->act_type_ = ActType_Relu6; | |||||
| param->act_type_ = ActType_Relu6; | |||||
| break; | break; | ||||
| default: | default: | ||||
| conv_param->act_type_ = ActType_No; | |||||
| param->act_type_ = ActType_No; | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(conv_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Conv2DFusion, PopulateConvParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Conv2DFusion, PopulateConvParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,29 +16,30 @@ | |||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/resize_parameter.h" | #include "nnacl/resize_parameter.h" | ||||
| using mindspore::schema::PrimitiveType_CropAndResize; | using mindspore::schema::PrimitiveType_CropAndResize; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateCropAndResizeParameter(const void *prim) { | OpParameter *PopulateCropAndResizeParameter(const void *prim) { | ||||
| auto *crop_resize_param = reinterpret_cast<CropAndResizeParameter *>(malloc(sizeof(CropAndResizeParameter))); | |||||
| if (crop_resize_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc CropAndResizeParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(crop_resize_param, 0, sizeof(CropAndResizeParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| crop_resize_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_CropAndResize(); | |||||
| auto value = primitive->value_as_CropAndResize(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<CropAndResizeParameter *>(malloc(sizeof(CropAndResizeParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc CropAndResizeParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| crop_resize_param->method_ = static_cast<int>(param->method()); | |||||
| crop_resize_param->extrapolation_value_ = param->extrapolation_value(); | |||||
| return reinterpret_cast<OpParameter *>(crop_resize_param); | |||||
| memset(param, 0, sizeof(CropAndResizeParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->method_ = static_cast<int>(value->method()); | |||||
| param->extrapolation_value_ = value->extrapolation_value(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_CropAndResize, PopulateCropAndResizeParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_CropAndResize, PopulateCropAndResizeParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,39 +19,42 @@ using mindspore::schema::PrimitiveType_Crop; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateCropParameter(const void *prim) { | OpParameter *PopulateCropParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto crop_prim = primitive->value_as_Crop(); | |||||
| if (crop_prim == nullptr) { | |||||
| MS_LOG(ERROR) << "crop_prim is nullptr"; | |||||
| auto value = primitive->value_as_Crop(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto param_offset = crop_prim->offsets(); | |||||
| auto *param = reinterpret_cast<CropParameter *>(malloc(sizeof(CropParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc CropParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(CropParameter)); | |||||
| auto param_offset = value->offsets(); | |||||
| if (param_offset == nullptr) { | if (param_offset == nullptr) { | ||||
| MS_LOG(ERROR) << "param_offset is nullptr"; | MS_LOG(ERROR) << "param_offset is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (param_offset->size() > COMM_SHAPE_SIZE) { | if (param_offset->size() > COMM_SHAPE_SIZE) { | ||||
| MS_LOG(ERROR) << "crop_param offset size(" << param_offset->size() << ") should <= " << COMM_SHAPE_SIZE; | |||||
| MS_LOG(ERROR) << "param offset size(" << param_offset->size() << ") should <= " << COMM_SHAPE_SIZE; | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto *crop_param = reinterpret_cast<CropParameter *>(malloc(sizeof(CropParameter))); | |||||
| if (crop_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc CropParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(crop_param, 0, sizeof(CropParameter)); | |||||
| crop_param->op_parameter_.type_ = primitive->value_type(); | |||||
| crop_param->axis_ = crop_prim->axis(); | |||||
| crop_param->offset_size_ = param_offset->size(); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->axis_ = value->axis(); | |||||
| param->offset_size_ = param_offset->size(); | |||||
| for (size_t i = 0; i < param_offset->size(); ++i) { | for (size_t i = 0; i < param_offset->size(); ++i) { | ||||
| crop_param->offset_[i] = *(param_offset->begin() + i); | |||||
| param->offset_[i] = *(param_offset->begin() + i); | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(crop_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Crop, PopulateCropParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Crop, PopulateCropParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,22 +19,27 @@ using mindspore::schema::PrimitiveType_CumSum; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateCumSumParameter(const void *prim) { | OpParameter *PopulateCumSumParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| auto cumsum_prim = primitive->value_as_CumSum(); | |||||
| CumSumParameter *cumsum_param = reinterpret_cast<CumSumParameter *>(malloc(sizeof(CumSumParameter))); | |||||
| if (cumsum_param == nullptr) { | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto value = primitive->value_as_CumSum(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<CumSumParameter *>(malloc(sizeof(CumSumParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc CumsumParameter failed."; | MS_LOG(ERROR) << "malloc CumsumParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(cumsum_param, 0, sizeof(CumSumParameter)); | |||||
| cumsum_param->op_parameter_.type_ = primitive->value_type(); | |||||
| cumsum_param->exclusive_ = cumsum_prim->exclusive(); | |||||
| cumsum_param->reverse_ = cumsum_prim->reverse(); | |||||
| return reinterpret_cast<OpParameter *>(cumsum_param); | |||||
| memset(param, 0, sizeof(CumSumParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->exclusive_ = value->exclusive(); | |||||
| param->reverse_ = value->reverse(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_CumSum, PopulateCumSumParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_CumSum, PopulateCumSumParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -18,20 +18,21 @@ using mindspore::schema::PrimitiveType_CustomExtractFeatures; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateExtractFeaturesParameter(const void *prim) { | OpParameter *PopulateExtractFeaturesParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "new OpParameter failed."; | MS_LOG(ERROR) << "new OpParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(OpParameter)); | memset(param, 0, sizeof(OpParameter)); | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->type_ = primitive->value_type(); | param->type_ = primitive->value_type(); | ||||
| return param; | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_CustomExtractFeatures, PopulateExtractFeaturesParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_CustomExtractFeatures, PopulateExtractFeaturesParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,17 +19,20 @@ using mindspore::schema::PrimitiveType_CustomNormalize; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateCustomNormalizeParameter(const void *prim) { | OpParameter *PopulateCustomNormalizeParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "new OpParameter failed."; | MS_LOG(ERROR) << "new OpParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(OpParameter)); | memset(param, 0, sizeof(OpParameter)); | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->type_ = primitive->value_type(); | param->type_ = primitive->value_type(); | ||||
| return param; | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_CustomNormalize, PopulateCustomNormalizeParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_CustomNormalize, PopulateCustomNormalizeParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_CustomPredict; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateCustomPredictParameter(const void *prim) { | OpParameter *PopulateCustomPredictParameter(const void *prim) { | ||||
| PredictParameter *param = reinterpret_cast<PredictParameter *>(malloc(sizeof(PredictParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc param failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(PredictParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_CustomPredict(); | auto value = primitive->value_as_CustomPredict(); | ||||
| @@ -33,12 +27,20 @@ OpParameter *PopulateCustomPredictParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto *param = reinterpret_cast<PredictParameter *>(malloc(sizeof(PredictParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc param failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(PredictParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | param->op_parameter_.type_ = primitive->value_type(); | ||||
| param->output_num = value->output_num(); | param->output_num = value->output_num(); | ||||
| param->weight_threshold = value->weight_threshold(); | param->weight_threshold = value->weight_threshold(); | ||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_CustomPredict, PopulateCustomPredictParameter, SCHEMA_CUR); | |||||
| REG_POPULATE(PrimitiveType_CustomPredict, PopulateCustomPredictParameter, SCHEMA_CUR); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,74 +21,77 @@ using mindspore::schema::PrimitiveType_Conv2dTransposeFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateDeconvParameter(const void *prim) { | OpParameter *PopulateDeconvParameter(const void *prim) { | ||||
| auto *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(conv_param, 0, sizeof(ConvParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| conv_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto conv_primitive = primitive->value_as_Conv2dTransposeFusion(); | |||||
| if (conv_primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "conv_primitive is nullptr"; | |||||
| auto value = primitive->value_as_Conv2dTransposeFusion(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto kernel_size = conv_primitive->kernel_size(); | |||||
| auto stride = conv_primitive->stride(); | |||||
| auto pad_list = conv_primitive->pad_list(); | |||||
| auto dilation = conv_primitive->dilation(); | |||||
| auto output_paddings = conv_primitive->output_paddings(); | |||||
| memset(param, 0, sizeof(ConvParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto kernel_size = value->kernel_size(); | |||||
| auto stride = value->stride(); | |||||
| auto pad_list = value->pad_list(); | |||||
| auto dilation = value->dilation(); | |||||
| auto output_paddings = value->output_paddings(); | |||||
| if (kernel_size == nullptr || stride == nullptr || dilation == nullptr || output_paddings == nullptr) { | if (kernel_size == nullptr || stride == nullptr || dilation == nullptr || output_paddings == nullptr) { | ||||
| MS_LOG(ERROR) << "nullptr"; | MS_LOG(ERROR) << "nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| conv_param->kernel_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| conv_param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| conv_param->group_ = static_cast<int>(conv_primitive->group()); | |||||
| conv_param->stride_h_ = static_cast<int>(*(stride->begin())); | |||||
| conv_param->stride_w_ = static_cast<int>(*(stride->begin() + 1)); | |||||
| conv_param->output_padding_h_ = static_cast<int>(*(output_paddings->begin())); | |||||
| conv_param->output_padding_w_ = static_cast<int>(*(output_paddings->begin() + 1)); | |||||
| switch (conv_primitive->pad_mode()) { | |||||
| param->kernel_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| param->group_ = static_cast<int>(value->group()); | |||||
| param->stride_h_ = static_cast<int>(*(stride->begin())); | |||||
| param->stride_w_ = static_cast<int>(*(stride->begin() + 1)); | |||||
| param->output_padding_h_ = static_cast<int>(*(output_paddings->begin())); | |||||
| param->output_padding_w_ = static_cast<int>(*(output_paddings->begin() + 1)); | |||||
| switch (value->pad_mode()) { | |||||
| case schema::PadMode_SAME: | case schema::PadMode_SAME: | ||||
| conv_param->pad_mode_ = Pad_same; | |||||
| param->pad_mode_ = Pad_same; | |||||
| break; | break; | ||||
| case schema::PadMode_VALID: | case schema::PadMode_VALID: | ||||
| conv_param->pad_mode_ = Pad_valid; | |||||
| param->pad_mode_ = Pad_valid; | |||||
| break; | break; | ||||
| default: | default: | ||||
| conv_param->pad_mode_ = Pad_pad; | |||||
| param->pad_mode_ = Pad_pad; | |||||
| } | } | ||||
| if (pad_list == nullptr || pad_list->size() < 4) { | if (pad_list == nullptr || pad_list->size() < 4) { | ||||
| conv_param->pad_u_ = 0; | |||||
| conv_param->pad_d_ = 0; | |||||
| conv_param->pad_l_ = 0; | |||||
| conv_param->pad_r_ = 0; | |||||
| param->pad_u_ = 0; | |||||
| param->pad_d_ = 0; | |||||
| param->pad_l_ = 0; | |||||
| param->pad_r_ = 0; | |||||
| } else { | } else { | ||||
| conv_param->pad_u_ = static_cast<int>(*(pad_list->begin())); | |||||
| conv_param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1)); | |||||
| conv_param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2)); | |||||
| conv_param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3)); | |||||
| param->pad_u_ = static_cast<int>(*(pad_list->begin())); | |||||
| param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1)); | |||||
| param->pad_l_ = static_cast<int>(*(pad_list->begin() + 2)); | |||||
| param->pad_r_ = static_cast<int>(*(pad_list->begin() + 3)); | |||||
| } | } | ||||
| conv_param->dilation_h_ = static_cast<int>(*(dilation->begin())); | |||||
| conv_param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1)); | |||||
| conv_param->input_channel_ = static_cast<int>(conv_primitive->in_channel()); | |||||
| conv_param->output_channel_ = static_cast<int>(conv_primitive->out_channel()); | |||||
| auto act_type = conv_primitive->activation_type(); | |||||
| param->dilation_h_ = static_cast<int>(*(dilation->begin())); | |||||
| param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1)); | |||||
| param->input_channel_ = static_cast<int>(value->in_channel()); | |||||
| param->output_channel_ = static_cast<int>(value->out_channel()); | |||||
| auto act_type = value->activation_type(); | |||||
| switch (act_type) { | switch (act_type) { | ||||
| case schema::ActivationType_RELU: | case schema::ActivationType_RELU: | ||||
| conv_param->act_type_ = ActType_Relu; | |||||
| param->act_type_ = ActType_Relu; | |||||
| break; | break; | ||||
| case schema::ActivationType_RELU6: | case schema::ActivationType_RELU6: | ||||
| conv_param->act_type_ = ActType_Relu6; | |||||
| param->act_type_ = ActType_Relu6; | |||||
| break; | break; | ||||
| default: | default: | ||||
| conv_param->act_type_ = ActType_No; | |||||
| param->act_type_ = ActType_No; | |||||
| break; | break; | ||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(conv_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_Conv2dTransposeFusion, PopulateDeconvParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Conv2dTransposeFusion, PopulateDeconvParameter, SCHEMA_CUR) | ||||
| @@ -22,14 +22,16 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *DefaultPopulateParameter(const void *prim) { | OpParameter *DefaultPopulateParameter(const void *prim) { | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = static_cast<OpParameter *>(malloc(sizeof(OpParameter))); | auto *param = static_cast<OpParameter *>(malloc(sizeof(OpParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "Malloc OpParameter failed."; | MS_LOG(ERROR) << "Malloc OpParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(OpParameter)); | memset(param, 0, sizeof(OpParameter)); | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->type_ = primitive->value_type(); | param->type_ = primitive->value_type(); | ||||
| return param; | return param; | ||||
| } | } | ||||
| @@ -21,22 +21,24 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| namespace { | namespace { | ||||
| OpParameter *PopulateDepthToSpaceParameter(const void *prim) { | OpParameter *PopulateDepthToSpaceParameter(const void *prim) { | ||||
| auto *depth_space_param = reinterpret_cast<DepthToSpaceParameter *>(malloc(sizeof(DepthToSpaceParameter))); | |||||
| if (depth_space_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc DepthToSpaceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(depth_space_param, 0, sizeof(DepthToSpaceParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto param = primitive->value_as_DepthToSpace(); | |||||
| auto value = primitive->value_as_DepthToSpace(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<DepthToSpaceParameter *>(malloc(sizeof(DepthToSpaceParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc DepthToSpaceParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| depth_space_param->op_parameter_.type_ = primitive->value_type(); | |||||
| depth_space_param->block_size_ = param->block_size(); | |||||
| return reinterpret_cast<OpParameter *>(depth_space_param); | |||||
| memset(param, 0, sizeof(DepthToSpaceParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->block_size_ = value->block_size(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -19,13 +19,14 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateConvDwParameter(const void *primitive) { | OpParameter *PopulateConvDwParameter(const void *primitive) { | ||||
| auto *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_param == nullptr) { | |||||
| auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | MS_LOG(ERROR) << "malloc ConvParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(conv_param, 0, sizeof(ConvParameter)); | |||||
| return reinterpret_cast<OpParameter *>(conv_param); | |||||
| memset(param, 0, sizeof(ConvParameter)); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,43 +19,43 @@ using mindspore::schema::PrimitiveType_DetectionPostProcess; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateDetectionPostProcessParameter(const void *prim) { | OpParameter *PopulateDetectionPostProcessParameter(const void *prim) { | ||||
| auto *detection_post_process_parameter = | |||||
| reinterpret_cast<DetectionPostProcessParameter *>(malloc(sizeof(DetectionPostProcessParameter))); | |||||
| if (detection_post_process_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc EluParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(detection_post_process_parameter, 0, sizeof(DetectionPostProcessParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| detection_post_process_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_DetectionPostProcess(); | |||||
| auto value = primitive->value_as_DetectionPostProcess(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<DetectionPostProcessParameter *>(malloc(sizeof(DetectionPostProcessParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc DetectionPostProcessParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto scale = param->scale(); | |||||
| memset(param, 0, sizeof(DetectionPostProcessParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto scale = value->scale(); | |||||
| if (scale == nullptr) { | if (scale == nullptr) { | ||||
| MS_LOG(ERROR) << "scale is nullptr"; | MS_LOG(ERROR) << "scale is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| detection_post_process_parameter->h_scale_ = *(scale->begin()); | |||||
| detection_post_process_parameter->w_scale_ = *(scale->begin() + 1); | |||||
| detection_post_process_parameter->x_scale_ = *(scale->begin() + 2); | |||||
| detection_post_process_parameter->y_scale_ = *(scale->begin() + 3); | |||||
| detection_post_process_parameter->nms_iou_threshold_ = param->nms_iou_threshold(); | |||||
| detection_post_process_parameter->nms_score_threshold_ = param->nms_score_threshold(); | |||||
| detection_post_process_parameter->max_detections_ = param->max_detections(); | |||||
| detection_post_process_parameter->detections_per_class_ = param->detections_per_class(); | |||||
| detection_post_process_parameter->max_classes_per_detection_ = param->max_classes_per_detection(); | |||||
| detection_post_process_parameter->num_classes_ = param->num_classes(); | |||||
| detection_post_process_parameter->use_regular_nms_ = param->use_regular_nms(); | |||||
| return reinterpret_cast<OpParameter *>(detection_post_process_parameter); | |||||
| param->h_scale_ = *(scale->begin()); | |||||
| param->w_scale_ = *(scale->begin() + 1); | |||||
| param->x_scale_ = *(scale->begin() + 2); | |||||
| param->y_scale_ = *(scale->begin() + 3); | |||||
| param->nms_iou_threshold_ = value->nms_iou_threshold(); | |||||
| param->nms_score_threshold_ = value->nms_score_threshold(); | |||||
| param->max_detections_ = value->max_detections(); | |||||
| param->detections_per_class_ = value->detections_per_class(); | |||||
| param->max_classes_per_detection_ = value->max_classes_per_detection(); | |||||
| param->num_classes_ = value->num_classes(); | |||||
| param->use_regular_nms_ = value->use_regular_nms(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_DetectionPostProcess, PopulateDetectionPostProcessParameter, SCHEMA_CUR); | |||||
| REG_POPULATE(PrimitiveType_DetectionPostProcess, PopulateDetectionPostProcessParameter, SCHEMA_CUR); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,9 +21,10 @@ namespace lite { | |||||
| OpParameter *PopulateDivParameter(const void *prim) { | OpParameter *PopulateDivParameter(const void *prim) { | ||||
| auto *param = PopulateArithmeticCommonPara(prim); | auto *param = PopulateArithmeticCommonPara(prim); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | |||||
| MS_LOG(ERROR) << "get PopulateArithmeticCommonPara failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| } | } | ||||
| @@ -19,24 +19,24 @@ using mindspore::schema::PrimitiveType_Eltwise; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateEltwiseParameter(const void *prim) { | OpParameter *PopulateEltwiseParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto value = primitive->value_as_Eltwise(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| ArithmeticParameter *param = PopulateArithmeticCommonPara(prim); | ArithmeticParameter *param = PopulateArithmeticCommonPara(prim); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto eltwise_param = primitive->value_as_Eltwise(); | |||||
| if (eltwise_param == nullptr) { | |||||
| MS_LOG(ERROR) << "eltwise_param is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| param->eltwise_mode_ = eltwise_param->mode(); | |||||
| param->eltwise_mode_ = value->mode(); | |||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Eltwise, PopulateEltwiseParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Eltwise, PopulateEltwiseParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,26 +19,27 @@ using mindspore::schema::PrimitiveType_Elu; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateEluParameter(const void *prim) { | OpParameter *PopulateEluParameter(const void *prim) { | ||||
| auto *elu_parameter = reinterpret_cast<EluParameter *>(malloc(sizeof(EluParameter))); | |||||
| if (elu_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc EluParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(elu_parameter, 0, sizeof(EluParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| elu_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_Elu(); | |||||
| auto value = primitive->value_as_Elu(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<EluParameter *>(malloc(sizeof(EluParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc EluParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| elu_parameter->alpha_ = param->alpha(); | |||||
| return reinterpret_cast<OpParameter *>(elu_parameter); | |||||
| memset(param, 0, sizeof(EluParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->alpha_ = value->alpha(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Elu, PopulateEluParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Elu, PopulateEluParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,15 +19,7 @@ using mindspore::schema::PrimitiveType_EmbeddingLookupFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateEmbeddingLookupParameter(const void *prim) { | OpParameter *PopulateEmbeddingLookupParameter(const void *prim) { | ||||
| auto *param = reinterpret_cast<EmbeddingLookupParameter *>(malloc(sizeof(EmbeddingLookupParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc EmbeddingLookupParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(EmbeddingLookupParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_EmbeddingLookupFusion(); | auto value = primitive->value_as_EmbeddingLookupFusion(); | ||||
| @@ -35,6 +27,14 @@ OpParameter *PopulateEmbeddingLookupParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto *param = reinterpret_cast<EmbeddingLookupParameter *>(malloc(sizeof(EmbeddingLookupParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc EmbeddingLookupParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(EmbeddingLookupParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | param->op_parameter_.type_ = primitive->value_type(); | ||||
| param->max_norm_ = value->max_norm(); | param->max_norm_ = value->max_norm(); | ||||
| if (param->max_norm_ < 0) { | if (param->max_norm_ < 0) { | ||||
| @@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_ExpFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateExpParameter(const void *prim) { | OpParameter *PopulateExpParameter(const void *prim) { | ||||
| auto *exp_parameter = reinterpret_cast<ExpParameter *>(malloc(sizeof(ExpParameter))); | |||||
| if (exp_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ExpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(exp_parameter, 0, sizeof(ExpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_ExpFusion(); | auto value = primitive->value_as_ExpFusion(); | ||||
| @@ -34,16 +27,24 @@ OpParameter *PopulateExpParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| exp_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| exp_parameter->base_ = value->base(); | |||||
| exp_parameter->scale_ = value->scale(); | |||||
| exp_parameter->shift_ = value->shift(); | |||||
| if (exp_parameter->base_ != -1 && exp_parameter->base_ <= 0) { | |||||
| MS_LOG(ERROR) << "Exp base must be strictly positive, got " << exp_parameter->base_; | |||||
| free(exp_parameter); | |||||
| auto *param = reinterpret_cast<ExpParameter *>(malloc(sizeof(ExpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ExpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(ExpParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->base_ = value->base(); | |||||
| param->scale_ = value->scale(); | |||||
| param->shift_ = value->shift(); | |||||
| if (param->base_ != -1 && param->base_ <= 0) { | |||||
| MS_LOG(ERROR) << "Exp base must be strictly positive, got " << param->base_; | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(exp_parameter); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_ExpFusion, PopulateExpParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ExpFusion, PopulateExpParameter, SCHEMA_CUR) | ||||
| @@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_ExpandDims; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateExpandDimsParameter(const void *prim) { | OpParameter *PopulateExpandDimsParameter(const void *prim) { | ||||
| auto *expand_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (expand_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ExpandDimsParameter failed."; | MS_LOG(ERROR) << "malloc ExpandDimsParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(expand_param, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| expand_param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(expand_param); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_ExpandDims, PopulateExpandDimsParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ExpandDims, PopulateExpandDimsParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_Fill; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateFillParameter(const void *prim) { | OpParameter *PopulateFillParameter(const void *prim) { | ||||
| auto *fill_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (fill_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc FillParameter failed."; | MS_LOG(ERROR) << "malloc FillParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(fill_param, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| fill_param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(fill_param); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Fill, PopulateFillParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Fill, PopulateFillParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,17 +19,18 @@ using mindspore::schema::PrimitiveType_Flatten; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateFlattenParameter(const void *prim) { | OpParameter *PopulateFlattenParameter(const void *prim) { | ||||
| auto *flatten_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (flatten_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc FlattenParameter failed."; | MS_LOG(ERROR) << "malloc FlattenParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(flatten_param, 0, sizeof(OpParameter)); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| flatten_param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(flatten_param); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_Flatten, PopulateFlattenParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Flatten, PopulateFlattenParameter, SCHEMA_CUR) | ||||
| @@ -19,37 +19,37 @@ using mindspore::schema::PrimitiveType_FullConnection; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateFullconnectionParameter(const void *prim) { | OpParameter *PopulateFullconnectionParameter(const void *prim) { | ||||
| auto *matmul_param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter))); | |||||
| if (matmul_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc MatMulParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(matmul_param, 0, sizeof(MatMulParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| matmul_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto full_conn_prim = primitive->value_as_FullConnection(); | |||||
| if (full_conn_prim == nullptr) { | |||||
| MS_LOG(ERROR) << "full_conn_prim is nullptr"; | |||||
| auto value = primitive->value_as_FullConnection(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc MatMulParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| matmul_param->b_transpose_ = true; | |||||
| matmul_param->a_transpose_ = false; | |||||
| matmul_param->has_bias_ = full_conn_prim->has_bias(); | |||||
| if (full_conn_prim->activation_type() == schema::ActivationType_RELU) { | |||||
| matmul_param->act_type_ = ActType_Relu; | |||||
| } else if (full_conn_prim->activation_type() == schema::ActivationType_RELU6) { | |||||
| matmul_param->act_type_ = ActType_Relu6; | |||||
| memset(param, 0, sizeof(MatMulParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->b_transpose_ = true; | |||||
| param->a_transpose_ = false; | |||||
| param->has_bias_ = value->has_bias(); | |||||
| if (value->activation_type() == schema::ActivationType_RELU) { | |||||
| param->act_type_ = ActType_Relu; | |||||
| } else if (value->activation_type() == schema::ActivationType_RELU6) { | |||||
| param->act_type_ = ActType_Relu6; | |||||
| } else { | } else { | ||||
| matmul_param->act_type_ = ActType_No; | |||||
| param->act_type_ = ActType_No; | |||||
| } | } | ||||
| matmul_param->axis_ = full_conn_prim->axis(); | |||||
| matmul_param->use_axis_ = full_conn_prim->use_axis(); | |||||
| return reinterpret_cast<OpParameter *>(matmul_param); | |||||
| param->axis_ = value->axis(); | |||||
| param->use_axis_ = value->use_axis(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_FullConnection, PopulateFullconnectionParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_FullConnection, PopulateFullconnectionParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_FusedBatchNorm; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateFusedBatchNorm(const void *prim) { | OpParameter *PopulateFusedBatchNorm(const void *prim) { | ||||
| auto *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter))); | |||||
| if (batch_norm_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BatchNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(batch_norm_param, 0, sizeof(BatchNormParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_FusedBatchNorm(); | auto value = primitive->value_as_FusedBatchNorm(); | ||||
| @@ -33,11 +27,19 @@ OpParameter *PopulateFusedBatchNorm(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| batch_norm_param->op_parameter_.type_ = primitive->value_type(); | |||||
| batch_norm_param->epsilon_ = value->epsilon(); | |||||
| batch_norm_param->momentum_ = value->momentum(); | |||||
| batch_norm_param->fused_ = true; | |||||
| return reinterpret_cast<OpParameter *>(batch_norm_param); | |||||
| auto *param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BatchNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(BatchNormParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->epsilon_ = value->epsilon(); | |||||
| param->momentum_ = value->momentum(); | |||||
| param->fused_ = true; | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_FusedBatchNorm, PopulateFusedBatchNorm, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_FusedBatchNorm, PopulateFusedBatchNorm, SCHEMA_CUR) | ||||
| @@ -19,20 +19,20 @@ using mindspore::schema::PrimitiveType_GatherNd; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateGatherNdParameter(const void *prim) { | OpParameter *PopulateGatherNdParameter(const void *prim) { | ||||
| auto *gather_nd_param = reinterpret_cast<GatherNdParameter *>(malloc(sizeof(GatherNdParameter))); | |||||
| if (gather_nd_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<GatherNdParameter *>(malloc(sizeof(GatherNdParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc GatherNdParameter failed."; | MS_LOG(ERROR) << "malloc GatherNdParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(gather_nd_param, 0, sizeof(GatherNdParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| gather_nd_param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(gather_nd_param); | |||||
| memset(param, 0, sizeof(GatherNdParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_GatherNd, PopulateGatherNdParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_GatherNd, PopulateGatherNdParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,21 +19,20 @@ using mindspore::schema::PrimitiveType_Gather; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateGatherParameter(const void *prim) { | OpParameter *PopulateGatherParameter(const void *prim) { | ||||
| auto *gather_param = reinterpret_cast<GatherParameter *>(malloc(sizeof(GatherParameter))); | |||||
| if (gather_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<GatherParameter *>(malloc(sizeof(GatherParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc GatherParameter failed."; | MS_LOG(ERROR) << "malloc GatherParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(gather_param, 0, sizeof(GatherParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| gather_param->op_parameter_.type_ = primitive->value_type(); | |||||
| memset(param, 0, sizeof(GatherParameter)); | |||||
| return reinterpret_cast<OpParameter *>(gather_param); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Gather, PopulateGatherParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Gather, PopulateGatherParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,27 +19,26 @@ using mindspore::schema::PrimitiveType_GRU; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateGruParameter(const void *prim) { | OpParameter *PopulateGruParameter(const void *prim) { | ||||
| auto *gru_param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter))); | |||||
| if (gru_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc GruParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(gru_param, 0, sizeof(GruParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| gru_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_GRU(); | |||||
| if (param == nullptr) { | |||||
| free(gru_param); | |||||
| auto value = primitive->value_as_GRU(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "param is nullptr."; | MS_LOG(ERROR) << "param is nullptr."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| gru_param->bidirectional_ = param->bidirectional(); | |||||
| return reinterpret_cast<OpParameter *>(gru_param); | |||||
| auto *param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc GruParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(GruParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->bidirectional_ = value->bidirectional(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_GRU, PopulateGruParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_GRU, PopulateGruParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,17 +19,20 @@ using mindspore::schema::PrimitiveType_HashtableLookup; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateHashtableLookupParameter(const void *prim) { | OpParameter *PopulateHashtableLookupParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "new OpParameter failed."; | MS_LOG(ERROR) << "new OpParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(OpParameter)); | memset(param, 0, sizeof(OpParameter)); | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->type_ = primitive->value_type(); | param->type_ = primitive->value_type(); | ||||
| return param; | return param; | ||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_HashtableLookup, PopulateHashtableLookupParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_HashtableLookup, PopulateHashtableLookupParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_InstanceNorm; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateInstanceNormParameter(const void *prim) { | OpParameter *PopulateInstanceNormParameter(const void *prim) { | ||||
| auto *instance_norm_param = reinterpret_cast<InstanceNormParameter *>(malloc(sizeof(InstanceNormParameter))); | |||||
| if (instance_norm_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc InstanceNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(instance_norm_param, 0, sizeof(InstanceNormParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_InstanceNorm(); | auto value = primitive->value_as_InstanceNorm(); | ||||
| @@ -34,9 +27,17 @@ OpParameter *PopulateInstanceNormParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| instance_norm_param->op_parameter_.type_ = primitive->value_type(); | |||||
| instance_norm_param->epsilon_ = value->epsilon(); | |||||
| return reinterpret_cast<OpParameter *>(instance_norm_param); | |||||
| auto *param = reinterpret_cast<InstanceNormParameter *>(malloc(sizeof(InstanceNormParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc InstanceNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(InstanceNormParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->epsilon_ = value->epsilon(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_InstanceNorm, PopulateInstanceNormParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_InstanceNorm, PopulateInstanceNormParameter, SCHEMA_CUR) | ||||
| @@ -13,7 +13,6 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <cstdint> | |||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/l2_norm_parameter.h" | #include "nnacl/l2_norm_parameter.h" | ||||
| using mindspore::schema::PrimitiveType_L2NormalizeFusion; | using mindspore::schema::PrimitiveType_L2NormalizeFusion; | ||||
| @@ -21,13 +20,6 @@ using mindspore::schema::PrimitiveType_L2NormalizeFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateL2NormParameter(const void *prim) { | OpParameter *PopulateL2NormParameter(const void *prim) { | ||||
| auto *l2_norm_parameter = reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter))); | |||||
| if (l2_norm_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc L2NormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(l2_norm_parameter, 0, sizeof(L2NormParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_L2NormalizeFusion(); | auto value = primitive->value_as_L2NormalizeFusion(); | ||||
| @@ -35,32 +27,40 @@ OpParameter *PopulateL2NormParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| l2_norm_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| auto *param = reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc L2NormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(L2NormParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto axis_vec = value->axis(); | auto axis_vec = value->axis(); | ||||
| if (axis_vec == nullptr) { | if (axis_vec == nullptr) { | ||||
| MS_LOG(ERROR) << "axis_vec is nullptr"; | MS_LOG(ERROR) << "axis_vec is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| l2_norm_parameter->axis_num_ = axis_vec->size(); | |||||
| param->axis_num_ = axis_vec->size(); | |||||
| MS_ASSERT(axis_vec->size() < 8); | MS_ASSERT(axis_vec->size() < 8); | ||||
| for (size_t i = 0; i < axis_vec->size(); i++) { | for (size_t i = 0; i < axis_vec->size(); i++) { | ||||
| l2_norm_parameter->axis_[i] = static_cast<int>(axis_vec->Get(i)); | |||||
| param->axis_[i] = static_cast<int>(axis_vec->Get(i)); | |||||
| } | } | ||||
| if (value->epsilon() < 1e-6) { | if (value->epsilon() < 1e-6) { | ||||
| l2_norm_parameter->epsilon_ = 1e-6; | |||||
| param->epsilon_ = 1e-6; | |||||
| } else { | } else { | ||||
| l2_norm_parameter->epsilon_ = value->epsilon(); | |||||
| param->epsilon_ = value->epsilon(); | |||||
| } | } | ||||
| if (value->activation_type() == static_cast<int>(schema::ActivationType_RELU)) { | if (value->activation_type() == static_cast<int>(schema::ActivationType_RELU)) { | ||||
| l2_norm_parameter->act_type_ = ActType_Relu; | |||||
| param->act_type_ = ActType_Relu; | |||||
| } else if (value->activation_type() == static_cast<int>(schema::ActivationType_RELU6)) { | } else if (value->activation_type() == static_cast<int>(schema::ActivationType_RELU6)) { | ||||
| l2_norm_parameter->act_type_ = ActType_Relu6; | |||||
| param->act_type_ = ActType_Relu6; | |||||
| } else { | } else { | ||||
| l2_norm_parameter->act_type_ = ActType_No; | |||||
| param->act_type_ = ActType_No; | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(l2_norm_parameter); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_L2NormalizeFusion, PopulateL2NormParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_L2NormalizeFusion, PopulateL2NormParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -21,23 +21,25 @@ using mindspore::schema::PrimitiveType_LayerNormGrad; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateLayerNormGradParameter(const void *prim) { | OpParameter *PopulateLayerNormGradParameter(const void *prim) { | ||||
| auto layer_norm_grad_parameter = reinterpret_cast<LayerNormGradParameter *>(malloc(sizeof(LayerNormGradParameter))); | |||||
| if (layer_norm_grad_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LayerNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(layer_norm_grad_parameter, 0, sizeof(LayerNormGradParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| layer_norm_grad_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_LayerNormGrad(); | |||||
| auto value = primitive->value_as_LayerNormGrad(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto param = reinterpret_cast<LayerNormGradParameter *>(malloc(sizeof(LayerNormGradParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc LayerNormParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| layer_norm_grad_parameter->begin_norm_axis_ = param->begin_norm_axis(); | |||||
| layer_norm_grad_parameter->begin_params_axis_ = param->begin_params_axis(); | |||||
| return reinterpret_cast<OpParameter *>(layer_norm_grad_parameter); | |||||
| memset(param, 0, sizeof(LayerNormGradParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->begin_norm_axis_ = value->begin_norm_axis(); | |||||
| param->begin_params_axis_ = value->begin_params_axis(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_LayerNormGrad, PopulateLayerNormGradParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_LayerNormGrad, PopulateLayerNormGradParameter, SCHEMA_CUR); | ||||
| @@ -17,28 +17,31 @@ | |||||
| #include <cstdint> | #include <cstdint> | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| using mindspore::schema::PrimitiveType_LayerNormFusion; | using mindspore::schema::PrimitiveType_LayerNormFusion; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateLayerNormParameter(const void *prim) { | OpParameter *PopulateLayerNormParameter(const void *prim) { | ||||
| auto layer_norm_parameter = reinterpret_cast<LayerNormParameter *>(malloc(sizeof(LayerNormParameter))); | |||||
| if (layer_norm_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LayerNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(layer_norm_parameter, 0, sizeof(LayerNormParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| layer_norm_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_LayerNormFusion(); | |||||
| auto value = primitive->value_as_LayerNormFusion(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto param = reinterpret_cast<LayerNormParameter *>(malloc(sizeof(LayerNormParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc LayerNormParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| layer_norm_parameter->epsilon_ = param->epsilon(); | |||||
| layer_norm_parameter->elementwise_affine_ = param->elementwise_affine(); | |||||
| layer_norm_parameter->begin_norm_axis_ = static_cast<int>(param->begin_norm_axis()); | |||||
| layer_norm_parameter->begin_params_axis_ = static_cast<int>(param->begin_params_axis()); | |||||
| return reinterpret_cast<OpParameter *>(layer_norm_parameter); | |||||
| memset(param, 0, sizeof(LayerNormParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->epsilon_ = value->epsilon(); | |||||
| param->elementwise_affine_ = value->elementwise_affine(); | |||||
| param->begin_norm_axis_ = static_cast<int>(value->begin_norm_axis()); | |||||
| param->begin_params_axis_ = static_cast<int>(value->begin_params_axis()); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_LayerNormFusion, PopulateLayerNormParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_LayerNormFusion, PopulateLayerNormParameter, SCHEMA_CUR) | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_LRN; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateLocalResponseNormParameter(const void *prim) { | OpParameter *PopulateLocalResponseNormParameter(const void *prim) { | ||||
| auto *lrn_param = reinterpret_cast<LocalResponseNormParameter *>(malloc(sizeof(LocalResponseNormParameter))); | |||||
| if (lrn_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LocalResponseNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(lrn_param, 0, sizeof(LocalResponseNormParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_LRN(); | auto value = primitive->value_as_LRN(); | ||||
| @@ -33,12 +27,20 @@ OpParameter *PopulateLocalResponseNormParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| lrn_param->op_parameter_.type_ = primitive->value_type(); | |||||
| lrn_param->depth_radius_ = value->depth_radius(); | |||||
| lrn_param->bias_ = value->bias(); | |||||
| lrn_param->alpha_ = value->alpha(); | |||||
| lrn_param->beta_ = value->beta(); | |||||
| return reinterpret_cast<OpParameter *>(lrn_param); | |||||
| auto *param = reinterpret_cast<LocalResponseNormParameter *>(malloc(sizeof(LocalResponseNormParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LocalResponseNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(LocalResponseNormParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->depth_radius_ = value->depth_radius(); | |||||
| param->bias_ = value->bias(); | |||||
| param->alpha_ = value->alpha(); | |||||
| param->beta_ = value->beta(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_LRN, PopulateLocalResponseNormParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_LRN, PopulateLocalResponseNormParameter, SCHEMA_CUR); | ||||
| @@ -19,26 +19,26 @@ using mindspore::schema::PrimitiveType_LogSoftmax; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateLogSoftmaxParameter(const void *prim) { | OpParameter *PopulateLogSoftmaxParameter(const void *prim) { | ||||
| auto *log_softmax_param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter))); | |||||
| if (log_softmax_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LogSoftmaxParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(log_softmax_param, 0, sizeof(SoftmaxParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| log_softmax_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto prim_log_softmax = primitive->value_as_LogSoftmax(); | |||||
| if (prim_log_softmax == nullptr) { | |||||
| MS_LOG(ERROR) << "prim_log_softmax is nullptr"; | |||||
| auto value = primitive->value_as_LogSoftmax(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SoftmaxParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| log_softmax_param->axis_ = prim_log_softmax->axis(); | |||||
| return reinterpret_cast<OpParameter *>(log_softmax_param); | |||||
| memset(param, 0, sizeof(SoftmaxParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->axis_ = value->axis(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_LogSoftmax, PopulateLogSoftmaxParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_LogSoftmax, PopulateLogSoftmaxParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_LshProjection; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateLshProjectionParameter(const void *prim) { | OpParameter *PopulateLshProjectionParameter(const void *prim) { | ||||
| auto *lsh_project_param = reinterpret_cast<LshProjectionParameter *>(malloc(sizeof(LshProjectionParameter))); | |||||
| if (lsh_project_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LshProjectionParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(lsh_project_param, 0, sizeof(LshProjectionParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_LshProjection(); | auto value = primitive->value_as_LshProjection(); | ||||
| @@ -34,11 +27,19 @@ OpParameter *PopulateLshProjectionParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| lsh_project_param->op_parameter_.type_ = primitive->value_type(); | |||||
| lsh_project_param->lsh_type_ = value->type(); | |||||
| return reinterpret_cast<OpParameter *>(lsh_project_param); | |||||
| auto *param = reinterpret_cast<LshProjectionParameter *>(malloc(sizeof(LshProjectionParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LshProjectionParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(LshProjectionParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->lsh_type_ = value->type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_LshProjection, PopulateLshProjectionParameter, SCHEMA_CUR); | |||||
| REG_POPULATE(PrimitiveType_LshProjection, PopulateLshProjectionParameter, SCHEMA_CUR); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,30 +19,29 @@ using mindspore::schema::PrimitiveType_LSTM; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateLstmParameter(const void *prim) { | OpParameter *PopulateLstmParameter(const void *prim) { | ||||
| auto *lstm_param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter))); | |||||
| if (lstm_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LstmParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(lstm_param, 0, sizeof(LstmParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| lstm_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_LSTM(); | |||||
| auto value = primitive->value_as_LSTM(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr."; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| free(lstm_param); | |||||
| MS_LOG(ERROR) << "get Lstm param nullptr."; | |||||
| MS_LOG(ERROR) << "malloc LstmParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(LstmParameter)); | |||||
| lstm_param->bidirectional_ = param->bidirectional(); | |||||
| lstm_param->zoneout_cell_ = param->zoneout_cell(); | |||||
| lstm_param->zoneout_hidden_ = param->zoneout_hidden(); | |||||
| return reinterpret_cast<OpParameter *>(lstm_param); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->bidirectional_ = value->bidirectional(); | |||||
| param->zoneout_cell_ = value->zoneout_cell(); | |||||
| param->zoneout_hidden_ = value->zoneout_hidden(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_LSTM, PopulateLstmParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_LSTM, PopulateLstmParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,15 +16,10 @@ | |||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/matmul_parameter.h" | #include "nnacl/matmul_parameter.h" | ||||
| using mindspore::schema::PrimitiveType_MatMul; | using mindspore::schema::PrimitiveType_MatMul; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateMatMulParameter(const void *prim) { | OpParameter *PopulateMatMulParameter(const void *prim) { | ||||
| auto *matmul_param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter))); | |||||
| if (matmul_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc MatMulParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(matmul_param, 0, sizeof(MatMulParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_MatMul(); | auto value = primitive->value_as_MatMul(); | ||||
| @@ -32,13 +27,22 @@ OpParameter *PopulateMatMulParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| matmul_param->op_parameter_.type_ = primitive->value_type(); | |||||
| matmul_param->b_transpose_ = value->transpose_b(); | |||||
| matmul_param->a_transpose_ = value->transpose_a(); | |||||
| matmul_param->has_bias_ = false; | |||||
| matmul_param->act_type_ = ActType_No; | |||||
| return reinterpret_cast<OpParameter *>(matmul_param); | |||||
| auto *param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc MatMulParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(MatMulParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->b_transpose_ = value->transpose_b(); | |||||
| param->a_transpose_ = value->transpose_a(); | |||||
| param->has_bias_ = false; | |||||
| param->act_type_ = ActType_No; | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_MatMul, PopulateMatMulParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_MatMul, PopulateMatMulParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,16 +19,18 @@ using mindspore::schema::PrimitiveType_Merge; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateMergeParameter(const void *prim) { | OpParameter *PopulateMergeParameter(const void *prim) { | ||||
| auto *merge_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (merge_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc Merge parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(merge_parameter, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| merge_parameter->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(merge_parameter); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_Merge, PopulateMergeParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Merge, PopulateMergeParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,26 +19,26 @@ using mindspore::schema::PrimitiveType_Mfcc; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateMfccParameter(const void *prim) { | OpParameter *PopulateMfccParameter(const void *prim) { | ||||
| auto *arg_param = reinterpret_cast<MfccParameter *>(malloc(sizeof(MfccParameter))); | |||||
| if (arg_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc MfccParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arg_param, 0, sizeof(MfccParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| arg_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_Mfcc(); | |||||
| auto value = primitive->value_as_Mfcc(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<MfccParameter *>(malloc(sizeof(MfccParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc MfccParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| arg_param->dct_coeff_num_ = param->dct_coeff_num(); | |||||
| return reinterpret_cast<OpParameter *>(arg_param); | |||||
| memset(param, 0, sizeof(MfccParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->dct_coeff_num_ = value->dct_coeff_num(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Mfcc, PopulateMfccParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Mfcc, PopulateMfccParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,19 +20,19 @@ using mindspore::schema::PrimitiveType_MulFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateMulParameter(const void *prim) { | OpParameter *PopulateMulParameter(const void *prim) { | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| ArithmeticParameter *param = PopulateArithmeticCommonPara(prim); | ArithmeticParameter *param = PopulateArithmeticCommonPara(prim); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | param->op_parameter_.type_ = primitive->value_type(); | ||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_MulFusion, PopulateMulParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_MulFusion, PopulateMulParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_NonMaxSuppression; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateNonMaxSuppressionParameter(const void *prim) { | OpParameter *PopulateNonMaxSuppressionParameter(const void *prim) { | ||||
| auto *param = reinterpret_cast<NMSParameter *>(malloc(sizeof(NMSParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc param failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(NMSParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_NonMaxSuppression(); | auto value = primitive->value_as_NonMaxSuppression(); | ||||
| @@ -33,6 +27,14 @@ OpParameter *PopulateNonMaxSuppressionParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto *param = reinterpret_cast<NMSParameter *>(malloc(sizeof(NMSParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc NMSParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(NMSParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | param->op_parameter_.type_ = primitive->value_type(); | ||||
| param->center_point_box_ = value->center_point_box(); | param->center_point_box_ = value->center_point_box(); | ||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| @@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_OneHot; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateOneHotParameter(const void *prim) { | OpParameter *PopulateOneHotParameter(const void *prim) { | ||||
| auto *one_hot_param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter))); | |||||
| if (one_hot_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OneHotParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(one_hot_param, 0, sizeof(OneHotParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_OneHot(); | auto value = primitive->value_as_OneHot(); | ||||
| @@ -34,10 +27,19 @@ OpParameter *PopulateOneHotParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| one_hot_param->op_parameter_.type_ = primitive->value_type(); | |||||
| one_hot_param->axis_ = value->axis(); | |||||
| return reinterpret_cast<OpParameter *>(one_hot_param); | |||||
| auto *param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OneHotParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(OneHotParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->axis_ = value->axis(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_OneHot, PopulateOneHotParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_OneHot, PopulateOneHotParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,14 +19,16 @@ using mindspore::schema::PrimitiveType_OnesLike; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateOnesLikeParameter(const void *prim) { | OpParameter *PopulateOnesLikeParameter(const void *prim) { | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc OnesLike Parameter failed."; | |||||
| MS_LOG(ERROR) << "malloc OpParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(param, 0, sizeof(OpParameter)); | memset(param, 0, sizeof(OpParameter)); | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| param->type_ = primitive->value_type(); | param->type_ = primitive->value_type(); | ||||
| return param; | return param; | ||||
| } | } | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_PReLUFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulatePReLUParameter(const void *prim) { | OpParameter *PopulatePReLUParameter(const void *prim) { | ||||
| PReluParameter *param = reinterpret_cast<PReluParameter *>(malloc(sizeof(PReluParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PReluParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(PReluParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_PReLUFusion(); | auto value = primitive->value_as_PReLUFusion(); | ||||
| @@ -33,10 +27,19 @@ OpParameter *PopulatePReLUParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto *param = reinterpret_cast<PReluParameter *>(malloc(sizeof(PReluParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PReluParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(PReluParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | param->op_parameter_.type_ = primitive->value_type(); | ||||
| param->channelShared = value->channel_shared(); | param->channelShared = value->channel_shared(); | ||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_PReLUFusion, PopulatePReLUParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_PReLUFusion, PopulatePReLUParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_PadFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulatePadParameter(const void *prim) { | OpParameter *PopulatePadParameter(const void *prim) { | ||||
| auto *pad_param = reinterpret_cast<PadParameter *>(malloc(sizeof(PadParameter))); | |||||
| if (pad_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PadParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(pad_param, 0, sizeof(PadParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_PadFusion(); | auto value = primitive->value_as_PadFusion(); | ||||
| @@ -33,11 +27,20 @@ OpParameter *PopulatePadParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| pad_param->op_parameter_.type_ = primitive->value_type(); | |||||
| pad_param->pad_mode_ = value->padding_mode(); | |||||
| pad_param->constant_value_ = value->constant_value(); | |||||
| return reinterpret_cast<OpParameter *>(pad_param); | |||||
| auto *param = reinterpret_cast<PadParameter *>(malloc(sizeof(PadParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PadParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(PadParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->pad_mode_ = value->padding_mode(); | |||||
| param->constant_value_ = value->constant_value(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_PadFusion, PopulatePadParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_PadFusion, PopulatePadParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,14 +19,7 @@ using mindspore::schema::PrimitiveType_PartialFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulatePartialParameter(const void *prim) { | OpParameter *PopulatePartialParameter(const void *prim) { | ||||
| auto *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter))); | |||||
| if (partial_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc partial parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(partial_parameter, 0, sizeof(PartialParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_PartialFusion(); | auto value = primitive->value_as_PartialFusion(); | ||||
| @@ -34,11 +27,19 @@ OpParameter *PopulatePartialParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| partial_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| partial_parameter->sub_graph_index_ = value->sub_graph_index(); | |||||
| return reinterpret_cast<OpParameter *>(partial_parameter); | |||||
| auto *param = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc partial parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(PartialParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->sub_graph_index_ = value->sub_graph_index(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_PartialFusion, PopulatePartialParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_PartialFusion, PopulatePartialParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,155 +20,160 @@ using mindspore::schema::PrimitiveType_MaxPoolFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateAvgPoolParameter(const void *primitive) { | OpParameter *PopulateAvgPoolParameter(const void *primitive) { | ||||
| auto *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter))); | |||||
| if (pooling_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PoolingParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(pooling_param, 0, sizeof(PoolingParameter)); | |||||
| auto pooling_prim = static_cast<const schema::Primitive *>(primitive); | auto pooling_prim = static_cast<const schema::Primitive *>(primitive); | ||||
| MS_ASSERT(pooling_prim != nullptr); | MS_ASSERT(pooling_prim != nullptr); | ||||
| pooling_param->op_parameter_.type_ = pooling_prim->value_type(); | |||||
| auto pooling_primitive = pooling_prim->value_as_AvgPoolFusion(); | |||||
| if (pooling_primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "pooling_primitive is nullptr"; | |||||
| auto value = pooling_prim->value_as_AvgPoolFusion(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PoolingParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| pooling_param->pool_mode_ = PoolMode_AvgPool; | |||||
| pooling_param->global_ = pooling_primitive->global(); | |||||
| auto strides = pooling_primitive->strides(); | |||||
| memset(param, 0, sizeof(PoolingParameter)); | |||||
| param->op_parameter_.type_ = pooling_prim->value_type(); | |||||
| param->pool_mode_ = PoolMode_AvgPool; | |||||
| param->global_ = value->global(); | |||||
| auto strides = value->strides(); | |||||
| if (strides == nullptr) { | if (strides == nullptr) { | ||||
| MS_LOG(ERROR) << "strides is nullptr"; | MS_LOG(ERROR) << "strides is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| pooling_param->stride_w_ = static_cast<int>(*(strides->begin() + 1)); | |||||
| pooling_param->stride_h_ = static_cast<int>(*(strides->begin())); | |||||
| auto pad = pooling_primitive->pad(); | |||||
| param->stride_w_ = static_cast<int>(*(strides->begin() + 1)); | |||||
| param->stride_h_ = static_cast<int>(*(strides->begin())); | |||||
| auto pad = value->pad(); | |||||
| if (pad != nullptr) { | if (pad != nullptr) { | ||||
| pooling_param->pad_u_ = static_cast<int>(*(pad->begin())); | |||||
| pooling_param->pad_d_ = static_cast<int>(*(pad->begin() + 1)); | |||||
| pooling_param->pad_l_ = static_cast<int>(*(pad->begin() + 2)); | |||||
| pooling_param->pad_r_ = static_cast<int>(*(pad->begin() + 3)); | |||||
| param->pad_u_ = static_cast<int>(*(pad->begin())); | |||||
| param->pad_d_ = static_cast<int>(*(pad->begin() + 1)); | |||||
| param->pad_l_ = static_cast<int>(*(pad->begin() + 2)); | |||||
| param->pad_r_ = static_cast<int>(*(pad->begin() + 3)); | |||||
| } | } | ||||
| if (!pooling_param->global_) { | |||||
| auto kernel_size = pooling_primitive->kernel_size(); | |||||
| if (!param->global_) { | |||||
| auto kernel_size = value->kernel_size(); | |||||
| if (kernel_size == nullptr) { | if (kernel_size == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel_size is nullptr"; | MS_LOG(ERROR) << "kernel_size is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| pooling_param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| pooling_param->window_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| param->window_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| } | } | ||||
| auto round_mode = pooling_primitive->round_mode(); | |||||
| auto round_mode = value->round_mode(); | |||||
| switch (round_mode) { | switch (round_mode) { | ||||
| case schema::RoundMode_FLOOR: | case schema::RoundMode_FLOOR: | ||||
| pooling_param->round_mode_ = RoundMode_Floor; | |||||
| param->round_mode_ = RoundMode_Floor; | |||||
| break; | break; | ||||
| case schema::RoundMode_CEIL: | case schema::RoundMode_CEIL: | ||||
| pooling_param->round_mode_ = RoundMode_Ceil; | |||||
| param->round_mode_ = RoundMode_Ceil; | |||||
| break; | break; | ||||
| default: | default: | ||||
| pooling_param->round_mode_ = RoundMode_No; | |||||
| param->round_mode_ = RoundMode_No; | |||||
| break; | break; | ||||
| } | } | ||||
| if (pooling_primitive->activation_type() == schema::ActivationType_RELU) { | |||||
| pooling_param->act_type_ = ActType_Relu; | |||||
| } else if (pooling_primitive->activation_type() == schema::ActivationType_RELU6) { | |||||
| pooling_param->act_type_ = ActType_Relu6; | |||||
| if (value->activation_type() == schema::ActivationType_RELU) { | |||||
| param->act_type_ = ActType_Relu; | |||||
| } else if (value->activation_type() == schema::ActivationType_RELU6) { | |||||
| param->act_type_ = ActType_Relu6; | |||||
| } else { | } else { | ||||
| pooling_param->act_type_ = ActType_No; | |||||
| param->act_type_ = ActType_No; | |||||
| } | } | ||||
| switch (pooling_primitive->pad_mode()) { | |||||
| switch (value->pad_mode()) { | |||||
| case schema::PadMode_SAME: | case schema::PadMode_SAME: | ||||
| pooling_param->pad_mode_ = Pad_same; | |||||
| param->pad_mode_ = Pad_same; | |||||
| break; | break; | ||||
| case schema::PadMode_VALID: | case schema::PadMode_VALID: | ||||
| pooling_param->pad_mode_ = Pad_valid; | |||||
| param->pad_mode_ = Pad_valid; | |||||
| break; | break; | ||||
| default: | default: | ||||
| pooling_param->pad_mode_ = Pad_pad; | |||||
| param->pad_mode_ = Pad_pad; | |||||
| break; | break; | ||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(pooling_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| OpParameter *PopulateMaxPoolParameter(const void *primitive) { | OpParameter *PopulateMaxPoolParameter(const void *primitive) { | ||||
| auto *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter))); | |||||
| if (pooling_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PoolingParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(pooling_param, 0, sizeof(PoolingParameter)); | |||||
| auto pooling_prim = static_cast<const schema::Primitive *>(primitive); | auto pooling_prim = static_cast<const schema::Primitive *>(primitive); | ||||
| MS_ASSERT(pooling_prim != nullptr); | MS_ASSERT(pooling_prim != nullptr); | ||||
| pooling_param->op_parameter_.type_ = pooling_prim->value_type(); | |||||
| auto max_pool_prim = pooling_prim->value_as_MaxPoolFusion(); | |||||
| if (max_pool_prim == nullptr) { | |||||
| MS_LOG(ERROR) << "max_pool_prim is nullptr"; | |||||
| auto value = pooling_prim->value_as_MaxPoolFusion(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PoolingParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| pooling_param->pool_mode_ = PoolMode_MaxPool; | |||||
| pooling_param->global_ = max_pool_prim->global(); | |||||
| if (!pooling_param->global_) { | |||||
| auto kernel_size = max_pool_prim->kernel_size(); | |||||
| auto strides = max_pool_prim->strides(); | |||||
| memset(param, 0, sizeof(PoolingParameter)); | |||||
| param->op_parameter_.type_ = pooling_prim->value_type(); | |||||
| param->pool_mode_ = PoolMode_MaxPool; | |||||
| param->global_ = value->global(); | |||||
| if (!param->global_) { | |||||
| auto kernel_size = value->kernel_size(); | |||||
| auto strides = value->strides(); | |||||
| if (kernel_size == nullptr || strides == nullptr) { | if (kernel_size == nullptr || strides == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel_size or strides is nullptr"; | MS_LOG(ERROR) << "kernel_size or strides is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| pooling_param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| pooling_param->window_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| pooling_param->stride_w_ = static_cast<int>(*(strides->begin() + 1)); | |||||
| pooling_param->stride_h_ = static_cast<int>(*(strides->begin())); | |||||
| auto pad = max_pool_prim->pad(); | |||||
| param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||||
| param->window_h_ = static_cast<int>(*(kernel_size->begin())); | |||||
| param->stride_w_ = static_cast<int>(*(strides->begin() + 1)); | |||||
| param->stride_h_ = static_cast<int>(*(strides->begin())); | |||||
| auto pad = value->pad(); | |||||
| if (pad != nullptr) { | if (pad != nullptr) { | ||||
| pooling_param->pad_u_ = static_cast<int>(*(pad->begin())); | |||||
| pooling_param->pad_d_ = static_cast<int>(*(pad->begin() + 1)); | |||||
| pooling_param->pad_l_ = static_cast<int>(*(pad->begin() + 2)); | |||||
| pooling_param->pad_r_ = static_cast<int>(*(pad->begin() + 3)); | |||||
| param->pad_u_ = static_cast<int>(*(pad->begin())); | |||||
| param->pad_d_ = static_cast<int>(*(pad->begin() + 1)); | |||||
| param->pad_l_ = static_cast<int>(*(pad->begin() + 2)); | |||||
| param->pad_r_ = static_cast<int>(*(pad->begin() + 3)); | |||||
| } | } | ||||
| } | } | ||||
| auto round_mode = max_pool_prim->round_mode(); | |||||
| auto round_mode = value->round_mode(); | |||||
| switch (round_mode) { | switch (round_mode) { | ||||
| case schema::RoundMode_FLOOR: | case schema::RoundMode_FLOOR: | ||||
| pooling_param->round_mode_ = RoundMode_Floor; | |||||
| param->round_mode_ = RoundMode_Floor; | |||||
| break; | break; | ||||
| case schema::RoundMode_CEIL: | case schema::RoundMode_CEIL: | ||||
| pooling_param->round_mode_ = RoundMode_Ceil; | |||||
| param->round_mode_ = RoundMode_Ceil; | |||||
| break; | break; | ||||
| default: | default: | ||||
| pooling_param->round_mode_ = RoundMode_No; | |||||
| param->round_mode_ = RoundMode_No; | |||||
| break; | break; | ||||
| } | } | ||||
| if (max_pool_prim->activation_type() == schema::ActivationType_RELU) { | |||||
| pooling_param->act_type_ = ActType_Relu; | |||||
| } else if (max_pool_prim->activation_type() == schema::ActivationType_RELU6) { | |||||
| pooling_param->act_type_ = ActType_Relu6; | |||||
| if (value->activation_type() == schema::ActivationType_RELU) { | |||||
| param->act_type_ = ActType_Relu; | |||||
| } else if (value->activation_type() == schema::ActivationType_RELU6) { | |||||
| param->act_type_ = ActType_Relu6; | |||||
| } else { | } else { | ||||
| pooling_param->act_type_ = ActType_No; | |||||
| param->act_type_ = ActType_No; | |||||
| } | } | ||||
| switch (max_pool_prim->pad_mode()) { | |||||
| switch (value->pad_mode()) { | |||||
| case schema::PadMode_SAME: | case schema::PadMode_SAME: | ||||
| pooling_param->pad_mode_ = Pad_same; | |||||
| param->pad_mode_ = Pad_same; | |||||
| break; | break; | ||||
| case schema::PadMode_VALID: | case schema::PadMode_VALID: | ||||
| pooling_param->pad_mode_ = Pad_valid; | |||||
| param->pad_mode_ = Pad_valid; | |||||
| break; | break; | ||||
| default: | default: | ||||
| pooling_param->pad_mode_ = Pad_pad; | |||||
| param->pad_mode_ = Pad_pad; | |||||
| break; | break; | ||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(pooling_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_AvgPoolFusion, PopulateAvgPoolParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_AvgPoolFusion, PopulateAvgPoolParameter, SCHEMA_CUR) | ||||
| REG_POPULATE(PrimitiveType_MaxPoolFusion, PopulateMaxPoolParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_MaxPoolFusion, PopulateMaxPoolParameter, SCHEMA_CUR) | ||||
| @@ -19,27 +19,27 @@ using mindspore::schema::PrimitiveType_PowFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulatePowerParameter(const void *prim) { | OpParameter *PopulatePowerParameter(const void *prim) { | ||||
| auto *power_param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter))); | |||||
| if (power_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PowerParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(power_param, 0, sizeof(PowerParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| power_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto power_prim = primitive->value_as_PowFusion(); | |||||
| if (power_prim == nullptr) { | |||||
| MS_LOG(ERROR) << "power_prim is nullptr"; | |||||
| auto value = primitive->value_as_PowFusion(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PowerParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| power_param->scale_ = power_prim->scale(); | |||||
| power_param->shift_ = power_prim->shift(); | |||||
| return reinterpret_cast<OpParameter *>(power_param); | |||||
| memset(param, 0, sizeof(PowerParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->scale_ = value->scale(); | |||||
| param->shift_ = value->shift(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_PowFusion, PopulatePowerParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_PowFusion, PopulatePowerParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,13 +20,6 @@ using mindspore::schema::PrimitiveType_PriorBox; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulatePriorBoxParameter(const void *prim) { | OpParameter *PopulatePriorBoxParameter(const void *prim) { | ||||
| auto *prior_box_param = reinterpret_cast<PriorBoxParameter *>(malloc(sizeof(PriorBoxParameter))); | |||||
| if (prior_box_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PriorBoxParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(prior_box_param, 0, sizeof(PriorBoxParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_PriorBox(); | auto value = primitive->value_as_PriorBox(); | ||||
| @@ -34,67 +27,80 @@ OpParameter *PopulatePriorBoxParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| prior_box_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto *param = reinterpret_cast<PriorBoxParameter *>(malloc(sizeof(PriorBoxParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PriorBoxParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(PriorBoxParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto min_sizes = value->min_sizes(); | auto min_sizes = value->min_sizes(); | ||||
| if (min_sizes == nullptr) { | if (min_sizes == nullptr) { | ||||
| MS_LOG(ERROR) << "min_sizes is nullptr"; | MS_LOG(ERROR) << "min_sizes is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (min_sizes->size() > MAX_SHAPE_SIZE) { | if (min_sizes->size() > MAX_SHAPE_SIZE) { | ||||
| MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << MAX_SHAPE_SIZE << ", got " << min_sizes->size(); | MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << MAX_SHAPE_SIZE << ", got " << min_sizes->size(); | ||||
| free(prior_box_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| prior_box_param->min_sizes_size = min_sizes->size(); | |||||
| memcpy(prior_box_param->min_sizes, min_sizes->data(), min_sizes->size() * sizeof(int32_t)); | |||||
| param->min_sizes_size = min_sizes->size(); | |||||
| memcpy(param->min_sizes, min_sizes->data(), min_sizes->size() * sizeof(int32_t)); | |||||
| auto max_sizes = value->max_sizes(); | auto max_sizes = value->max_sizes(); | ||||
| if (max_sizes == nullptr) { | if (max_sizes == nullptr) { | ||||
| MS_LOG(ERROR) << "max_sizes is nullptr"; | MS_LOG(ERROR) << "max_sizes is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (max_sizes->size() > MAX_SHAPE_SIZE) { | if (max_sizes->size() > MAX_SHAPE_SIZE) { | ||||
| MS_LOG(ERROR) << "PriorBox max_sizes size exceeds max num " << MAX_SHAPE_SIZE << ", got " << max_sizes->size(); | MS_LOG(ERROR) << "PriorBox max_sizes size exceeds max num " << MAX_SHAPE_SIZE << ", got " << max_sizes->size(); | ||||
| free(prior_box_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| prior_box_param->max_sizes_size = max_sizes->size(); | |||||
| memcpy(prior_box_param->max_sizes, max_sizes->data(), max_sizes->size() * sizeof(int32_t)); | |||||
| param->max_sizes_size = max_sizes->size(); | |||||
| memcpy(param->max_sizes, max_sizes->data(), max_sizes->size() * sizeof(int32_t)); | |||||
| auto aspect_ratios = value->aspect_ratios(); | auto aspect_ratios = value->aspect_ratios(); | ||||
| if (aspect_ratios == nullptr) { | if (aspect_ratios == nullptr) { | ||||
| MS_LOG(ERROR) << "aspect_ratios is nullptr"; | MS_LOG(ERROR) << "aspect_ratios is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (aspect_ratios->size() > MAX_SHAPE_SIZE) { | if (aspect_ratios->size() > MAX_SHAPE_SIZE) { | ||||
| MS_LOG(ERROR) << "PriorBox aspect_ratios size exceeds max num " << MAX_SHAPE_SIZE << ", got " | MS_LOG(ERROR) << "PriorBox aspect_ratios size exceeds max num " << MAX_SHAPE_SIZE << ", got " | ||||
| << aspect_ratios->size(); | << aspect_ratios->size(); | ||||
| free(prior_box_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| prior_box_param->aspect_ratios_size = aspect_ratios->size(); | |||||
| memcpy(prior_box_param->aspect_ratios, aspect_ratios->data(), aspect_ratios->size() * sizeof(float)); | |||||
| param->aspect_ratios_size = aspect_ratios->size(); | |||||
| memcpy(param->aspect_ratios, aspect_ratios->data(), aspect_ratios->size() * sizeof(float)); | |||||
| auto variances = value->variances(); | auto variances = value->variances(); | ||||
| if (variances == nullptr) { | if (variances == nullptr) { | ||||
| MS_LOG(ERROR) << "variances is nullptr"; | MS_LOG(ERROR) << "variances is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (variances->size() != COMM_SHAPE_SIZE) { | if (variances->size() != COMM_SHAPE_SIZE) { | ||||
| MS_LOG(ERROR) << "PriorBox variances size should be " << COMM_SHAPE_SIZE << ", got " << variances->size(); | MS_LOG(ERROR) << "PriorBox variances size should be " << COMM_SHAPE_SIZE << ", got " << variances->size(); | ||||
| free(prior_box_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memcpy(prior_box_param->variances, variances->data(), COMM_SHAPE_SIZE * sizeof(float)); | |||||
| prior_box_param->flip = value->flip(); | |||||
| prior_box_param->clip = value->clip(); | |||||
| prior_box_param->offset = value->offset(); | |||||
| prior_box_param->image_size_h = value->image_size_h(); | |||||
| prior_box_param->image_size_w = value->image_size_w(); | |||||
| prior_box_param->step_h = value->step_h(); | |||||
| prior_box_param->step_w = value->step_w(); | |||||
| return reinterpret_cast<OpParameter *>(prior_box_param); | |||||
| memcpy(param->variances, variances->data(), COMM_SHAPE_SIZE * sizeof(float)); | |||||
| param->flip = value->flip(); | |||||
| param->clip = value->clip(); | |||||
| param->offset = value->offset(); | |||||
| param->image_size_h = value->image_size_h(); | |||||
| param->image_size_w = value->image_size_w(); | |||||
| param->step_h = value->step_h(); | |||||
| param->step_w = value->step_w(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_PriorBox, PopulatePriorBoxParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_PriorBox, PopulatePriorBoxParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_QuantDTypeCast; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateQuantDTypeCastParameter(const void *prim) { | OpParameter *PopulateQuantDTypeCastParameter(const void *prim) { | ||||
| auto *parameter = reinterpret_cast<QuantDTypeCastParameter *>(malloc(sizeof(QuantDTypeCastParameter))); | |||||
| if (parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc QuantDTypeCastParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(parameter, 0, sizeof(QuantDTypeCastParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_QuantDTypeCast(); | auto value = primitive->value_as_QuantDTypeCast(); | ||||
| @@ -33,12 +27,20 @@ OpParameter *PopulateQuantDTypeCastParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| parameter->srcT = value->src_t(); | |||||
| parameter->dstT = value->dst_t(); | |||||
| return reinterpret_cast<OpParameter *>(parameter); | |||||
| auto *param = reinterpret_cast<QuantDTypeCastParameter *>(malloc(sizeof(QuantDTypeCastParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc QuantDTypeCastParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(QuantDTypeCastParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->srcT = value->src_t(); | |||||
| param->dstT = value->dst_t(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_QuantDTypeCast, PopulateQuantDTypeCastParameter, SCHEMA_CUR); | |||||
| REG_POPULATE(PrimitiveType_QuantDTypeCast, PopulateQuantDTypeCastParameter, SCHEMA_CUR); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,25 +22,28 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| namespace { | namespace { | ||||
| OpParameter *PopulateRandomStandardNormalParameter(const void *prim) { | OpParameter *PopulateRandomStandardNormalParameter(const void *prim) { | ||||
| auto *random_parameter = reinterpret_cast<RandomParam *>(malloc(sizeof(RandomParam))); | |||||
| if (random_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc RandomStandardNormal parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(random_parameter, 0, sizeof(RandomParam)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| random_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_RandomStandardNormal(); | |||||
| auto value = primitive->value_as_RandomStandardNormal(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<RandomParam *>(malloc(sizeof(RandomParam))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc RandomParam failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| random_parameter->seed_ = param->seed(); | |||||
| random_parameter->seed2_ = param->seed2(); | |||||
| return reinterpret_cast<OpParameter *>(random_parameter); | |||||
| memset(param, 0, sizeof(RandomParam)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->seed_ = value->seed(); | |||||
| param->seed2_ = value->seed2(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| REG_POPULATE(PrimitiveType_RandomStandardNormal, PopulateRandomStandardNormalParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_RandomStandardNormal, PopulateRandomStandardNormalParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,29 +19,30 @@ using mindspore::schema::PrimitiveType_Range; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateRangeParameter(const void *prim) { | OpParameter *PopulateRangeParameter(const void *prim) { | ||||
| auto *range_param = reinterpret_cast<RangeParameter *>(malloc(sizeof(RangeParameter))); | |||||
| if (range_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc RangeParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(range_param, 0, sizeof(RangeParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| range_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto range_prim = primitive->value_as_Range(); | |||||
| if (range_prim == nullptr) { | |||||
| MS_LOG(ERROR) << "range_prim is nullptr"; | |||||
| auto value = primitive->value_as_Range(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| range_param->start_ = range_prim->start(); | |||||
| range_param->limit_ = range_prim->limit(); | |||||
| range_param->delta_ = range_prim->delta(); | |||||
| range_param->dType_ = range_prim->d_type(); | |||||
| return reinterpret_cast<OpParameter *>(range_param); | |||||
| auto *param = reinterpret_cast<RangeParameter *>(malloc(sizeof(RangeParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc RangeParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(RangeParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->start_ = value->start(); | |||||
| param->limit_ = value->limit(); | |||||
| param->delta_ = value->delta(); | |||||
| param->dType_ = value->d_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Range, PopulateRangeParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Range, PopulateRangeParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,20 +18,20 @@ using mindspore::schema::PrimitiveType_Rank; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateRankParameter(const void *prim) { | OpParameter *PopulateRankParameter(const void *prim) { | ||||
| auto *rank_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (rank_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc RankParameter failed."; | MS_LOG(ERROR) << "malloc RankParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(rank_param, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| rank_param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(rank_param); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Rank, PopulateRankParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Rank, PopulateRankParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -13,19 +13,13 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <memory> | |||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/reduce_parameter.h" | #include "nnacl/reduce_parameter.h" | ||||
| using mindspore::schema::PrimitiveType_ReduceFusion; | using mindspore::schema::PrimitiveType_ReduceFusion; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateReduceParameter(const void *prim) { | OpParameter *PopulateReduceParameter(const void *prim) { | ||||
| auto *reduce_param = reinterpret_cast<ReduceParameter *>(malloc(sizeof(ReduceParameter))); | |||||
| if (reduce_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ReduceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(reduce_param, 0, sizeof(ReduceParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_ReduceFusion(); | auto value = primitive->value_as_ReduceFusion(); | ||||
| @@ -33,15 +27,22 @@ OpParameter *PopulateReduceParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| reduce_param->op_parameter_.type_ = primitive->value_type(); | |||||
| reduce_param->keep_dims_ = value->keep_dims(); | |||||
| reduce_param->reduce_to_end_ = value->reduce_to_end(); | |||||
| reduce_param->coeff = value->coeff(); | |||||
| reduce_param->mode_ = static_cast<int>(value->mode()); | |||||
| return reinterpret_cast<OpParameter *>(reduce_param); | |||||
| auto *param = reinterpret_cast<ReduceParameter *>(malloc(sizeof(ReduceParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ReduceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(ReduceParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->keep_dims_ = value->keep_dims(); | |||||
| param->reduce_to_end_ = value->reduce_to_end(); | |||||
| param->coeff = value->coeff(); | |||||
| param->mode_ = static_cast<int>(value->mode()); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_ReduceFusion, PopulateReduceParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ReduceFusion, PopulateReduceParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,20 +19,20 @@ using mindspore::schema::PrimitiveType_Reshape; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateReshapeParameter(const void *prim) { | OpParameter *PopulateReshapeParameter(const void *prim) { | ||||
| auto *reshape_param = reinterpret_cast<ReshapeParameter *>(malloc(sizeof(ReshapeParameter))); | |||||
| if (reshape_param == nullptr) { | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<ReshapeParameter *>(malloc(sizeof(ReshapeParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ReshapeParameter failed."; | MS_LOG(ERROR) << "malloc ReshapeParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(reshape_param, 0, sizeof(ReshapeParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| reshape_param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(reshape_param); | |||||
| memset(param, 0, sizeof(ReshapeParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Reshape, PopulateReshapeParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Reshape, PopulateReshapeParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_Resize; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateResizeParameter(const void *prim) { | OpParameter *PopulateResizeParameter(const void *prim) { | ||||
| auto *resize_param = reinterpret_cast<ResizeParameter *>(malloc(sizeof(ResizeParameter))); | |||||
| if (resize_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ResizeParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(resize_param, 0, sizeof(ResizeParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_Resize(); | auto value = primitive->value_as_Resize(); | ||||
| @@ -33,18 +27,24 @@ OpParameter *PopulateResizeParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| resize_param->op_parameter_.type_ = primitive->value_type(); | |||||
| resize_param->method_ = static_cast<int>(value->method()); | |||||
| resize_param->new_height_ = value->new_height(); | |||||
| resize_param->new_width_ = value->new_width(); | |||||
| resize_param->coordinate_transform_mode_ = value->coordinate_transform_mode(); | |||||
| resize_param->preserve_aspect_ratio_ = value->preserve_aspect_ratio(); | |||||
| resize_param->cubic_coeff_ = value->cubic_coeff(); | |||||
| return reinterpret_cast<OpParameter *>(resize_param); | |||||
| auto *param = reinterpret_cast<ResizeParameter *>(malloc(sizeof(ResizeParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ResizeParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(ResizeParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->method_ = static_cast<int>(value->method()); | |||||
| param->new_height_ = value->new_height(); | |||||
| param->new_width_ = value->new_width(); | |||||
| param->coordinate_transform_mode_ = value->coordinate_transform_mode(); | |||||
| param->preserve_aspect_ratio_ = value->preserve_aspect_ratio(); | |||||
| param->cubic_coeff_ = value->cubic_coeff(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_Resize, PopulateResizeParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Resize, PopulateResizeParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_ReverseV2; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateReverseParameter(const void *prim) { | OpParameter *PopulateReverseParameter(const void *prim) { | ||||
| auto *reverse_param = reinterpret_cast<ReverseParameter *>(malloc(sizeof(ReverseParameter))); | |||||
| if (reverse_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ReverseParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(reverse_param, 0, sizeof(ReverseParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_ReverseV2(); | auto value = primitive->value_as_ReverseV2(); | ||||
| @@ -33,19 +27,27 @@ OpParameter *PopulateReverseParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| reverse_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto *param = reinterpret_cast<ReverseParameter *>(malloc(sizeof(ReverseParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ReverseParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(ReverseParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto flatAxis = value->axis(); | auto flatAxis = value->axis(); | ||||
| if (flatAxis == nullptr) { | if (flatAxis == nullptr) { | ||||
| MS_LOG(ERROR) << "flatAxis is nullptr"; | MS_LOG(ERROR) << "flatAxis is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| reverse_param->num_axis_ = flatAxis->size(); | |||||
| param->num_axis_ = flatAxis->size(); | |||||
| int i = 0; | int i = 0; | ||||
| for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) { | |||||
| reverse_param->axis_[i++] = *iter; | |||||
| for (auto flatAxi : *flatAxis) { | |||||
| param->axis_[i++] = static_cast<int>(flatAxi); | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(reverse_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_ReverseV2, PopulateReverseParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ReverseV2, PopulateReverseParameter, SCHEMA_CUR) | ||||
| @@ -19,29 +19,28 @@ using mindspore::schema::PrimitiveType_ReverseSequence; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateReverseSequenceParameter(const void *prim) { | OpParameter *PopulateReverseSequenceParameter(const void *prim) { | ||||
| auto *reverse_sequence_param = reinterpret_cast<ReverseSequenceParameter *>(malloc(sizeof(ReverseSequenceParameter))); | |||||
| if (reverse_sequence_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ReverseSequenceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(reverse_sequence_param, 0, sizeof(ReverseSequenceParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto param = primitive->value_as_ReverseSequence(); | |||||
| auto value = primitive->value_as_ReverseSequence(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<ReverseSequenceParameter *>(malloc(sizeof(ReverseSequenceParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc ReverseSequenceParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| reverse_sequence_param->op_parameter_.type_ = primitive->value_type(); | |||||
| reverse_sequence_param->seq_axis_ = static_cast<int>(param->seq_dim()); | |||||
| reverse_sequence_param->batch_axis_ = static_cast<int>(param->batch_dim()); | |||||
| return reinterpret_cast<OpParameter *>(reverse_sequence_param); | |||||
| memset(param, 0, sizeof(ReverseSequenceParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->seq_axis_ = static_cast<int>(value->seq_dim()); | |||||
| param->batch_axis_ = static_cast<int>(value->batch_dim()); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_ReverseSequence, PopulateReverseSequenceParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_ReverseSequence, PopulateReverseSequenceParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,29 +19,28 @@ using mindspore::schema::PrimitiveType_ROIPooling; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateROIPoolingParameter(const void *prim) { | OpParameter *PopulateROIPoolingParameter(const void *prim) { | ||||
| auto *roi_param = reinterpret_cast<ROIPoolingParameter *>(malloc(sizeof(ROIPoolingParameter))); | |||||
| if (roi_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ROIPoolingParameter failed."; | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto value = primitive->value_as_ROIPooling(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(roi_param, 0, sizeof(ROIPoolingParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| roi_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto roi_prim = primitive->value_as_ROIPooling(); | |||||
| if (roi_prim == nullptr) { | |||||
| MS_LOG(ERROR) << "roi_prim is nullptr"; | |||||
| auto *param = reinterpret_cast<ROIPoolingParameter *>(malloc(sizeof(ROIPoolingParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ROIPoolingParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| roi_param->pooledH_ = roi_prim->pooled_h(); | |||||
| roi_param->pooledW_ = roi_prim->pooled_w(); | |||||
| roi_param->scale_ = roi_prim->scale(); | |||||
| return reinterpret_cast<OpParameter *>(roi_param); | |||||
| memset(param, 0, sizeof(ROIPoolingParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->pooledH_ = value->pooled_h(); | |||||
| param->pooledW_ = value->pooled_w(); | |||||
| param->scale_ = value->scale(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_ROIPooling, PopulateROIPoolingParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ROIPooling, PopulateROIPoolingParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,14 +19,7 @@ using mindspore::schema::PrimitiveType_ScaleFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateScaleParameter(const void *prim) { | OpParameter *PopulateScaleParameter(const void *prim) { | ||||
| auto *scale_param = reinterpret_cast<ScaleParameter *>(malloc(sizeof(ScaleParameter))); | |||||
| if (scale_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ScaleParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(scale_param, 0, sizeof(ScaleParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_ScaleFusion(); | auto value = primitive->value_as_ScaleFusion(); | ||||
| @@ -34,12 +27,19 @@ OpParameter *PopulateScaleParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| scale_param->op_parameter_.type_ = primitive->value_type(); | |||||
| scale_param->axis_ = value->axis(); | |||||
| scale_param->activation_type_ = value->activation_type(); | |||||
| return reinterpret_cast<OpParameter *>(scale_param); | |||||
| auto *param = reinterpret_cast<ScaleParameter *>(malloc(sizeof(ScaleParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ScaleParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(ScaleParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->axis_ = value->axis(); | |||||
| param->activation_type_ = value->activation_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_ScaleFusion, PopulateScaleParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ScaleFusion, PopulateScaleParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -18,20 +18,21 @@ using mindspore::schema::PrimitiveType_ScatterNd; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateScatterNDParameter(const void *prim) { | OpParameter *PopulateScatterNDParameter(const void *prim) { | ||||
| auto *scatter_nd_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (scatter_nd_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ScatterNDParameter failed."; | MS_LOG(ERROR) << "malloc ScatterNDParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(scatter_nd_param, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| scatter_nd_param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(scatter_nd_param); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_ScatterNd, PopulateScatterNDParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_ScatterNd, PopulateScatterNDParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,20 +20,20 @@ using mindspore::schema::PrimitiveType_Shape; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateShapeParameter(const void *prim) { | OpParameter *PopulateShapeParameter(const void *prim) { | ||||
| auto *shape_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (shape_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ShapeParameter failed."; | MS_LOG(ERROR) << "malloc ShapeParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(shape_param, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| shape_param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(shape_param); | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Shape, PopulateShapeParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Shape, PopulateShapeParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_SkipGram; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateSkipGramParameter(const void *prim) { | OpParameter *PopulateSkipGramParameter(const void *prim) { | ||||
| auto *skipGramParameter = reinterpret_cast<SkipGramParameter *>(malloc(sizeof(SkipGramParameter))); | |||||
| if (skipGramParameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SkipGramParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(skipGramParameter, 0, sizeof(SkipGramParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_SkipGram(); | auto value = primitive->value_as_SkipGram(); | ||||
| @@ -33,12 +27,21 @@ OpParameter *PopulateSkipGramParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| skipGramParameter->op_parameter_.type_ = primitive->value_type(); | |||||
| skipGramParameter->ngram_size = value->ngram_size(); | |||||
| skipGramParameter->max_skip_size = value->max_skip_size(); | |||||
| skipGramParameter->include_all_ngrams = value->include_all_grams(); | |||||
| return reinterpret_cast<OpParameter *>(skipGramParameter); | |||||
| auto *param = reinterpret_cast<SkipGramParameter *>(malloc(sizeof(SkipGramParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SkipGramParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(SkipGramParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->ngram_size = value->ngram_size(); | |||||
| param->max_skip_size = value->max_skip_size(); | |||||
| param->include_all_ngrams = value->include_all_grams(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_SkipGram, PopulateSkipGramParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_SkipGram, PopulateSkipGramParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_SliceFusion; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateSliceParameter(const void *prim) { | OpParameter *PopulateSliceParameter(const void *prim) { | ||||
| auto *slice_param = reinterpret_cast<SliceParameter *>(malloc(sizeof(SliceParameter))); | |||||
| if (slice_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SliceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(slice_param, 0, sizeof(SliceParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_SliceFusion(); | auto value = primitive->value_as_SliceFusion(); | ||||
| @@ -33,17 +27,27 @@ OpParameter *PopulateSliceParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| slice_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto *param = reinterpret_cast<SliceParameter *>(malloc(sizeof(SliceParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SliceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(SliceParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto axes = value->axes(); | auto axes = value->axes(); | ||||
| if (axes == nullptr) { | if (axes == nullptr) { | ||||
| MS_LOG(ERROR) << "axes is nullptr"; | MS_LOG(ERROR) << "axes is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| for (size_t i = 0; i < axes->size(); ++i) { | for (size_t i = 0; i < axes->size(); ++i) { | ||||
| slice_param->axis_[i] = axes->Get(i); | |||||
| param->axis_[i] = axes->Get(i); | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(slice_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_SliceFusion, PopulateSliceParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_SliceFusion, PopulateSliceParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,36 +19,37 @@ using mindspore::schema::PrimitiveType_Softmax; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateSoftmaxParameter(const void *prim) { | OpParameter *PopulateSoftmaxParameter(const void *prim) { | ||||
| auto *softmax_param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter))); | |||||
| if (softmax_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SoftmaxParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(softmax_param, 0, sizeof(SoftmaxParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| softmax_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto prim_softmax = primitive->value_as_Softmax(); | |||||
| if (prim_softmax == nullptr) { | |||||
| MS_LOG(ERROR) << "prim_softmax is nullptr"; | |||||
| auto value = primitive->value_as_Softmax(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SoftmaxParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto axis = prim_softmax->axis(); | |||||
| memset(param, 0, sizeof(SoftmaxParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto axis = value->axis(); | |||||
| if (axis == nullptr) { | if (axis == nullptr) { | ||||
| MS_LOG(ERROR) << "axis is nullptr"; | MS_LOG(ERROR) << "axis is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (axis->size() != 1) { | if (axis->size() != 1) { | ||||
| MS_LOG(ERROR) << "axis number invalid!number: " << axis->size(); | MS_LOG(ERROR) << "axis number invalid!number: " << axis->size(); | ||||
| free(softmax_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| softmax_param->axis_ = axis->data()[0]; | |||||
| return reinterpret_cast<OpParameter *>(softmax_param); | |||||
| param->axis_ = axis->data()[0]; | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Softmax, PopulateSoftmaxParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Softmax, PopulateSoftmaxParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,42 +19,45 @@ using mindspore::schema::PrimitiveType_SpaceToBatchND; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) { | OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) { | ||||
| auto *space_batch_param_nd = reinterpret_cast<SpaceToBatchParameter *>(malloc(sizeof(SpaceToBatchParameter))); | |||||
| if (space_batch_param_nd == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SpaceToBatchParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(space_batch_param_nd, 0, sizeof(SpaceToBatchParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| space_batch_param_nd->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_SpaceToBatchND(); | |||||
| auto value = primitive->value_as_SpaceToBatchND(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *param = reinterpret_cast<SpaceToBatchParameter *>(malloc(sizeof(SpaceToBatchParameter))); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "param is nullptr"; | |||||
| MS_LOG(ERROR) << "malloc SpaceToBatchParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto block_shape = param->block_shape(); | |||||
| memset(param, 0, sizeof(SpaceToBatchParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto block_shape = value->block_shape(); | |||||
| if (block_shape == nullptr) { | if (block_shape == nullptr) { | ||||
| return reinterpret_cast<OpParameter *>(space_batch_param_nd); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| auto block_shapes = std::vector<int64_t>(block_shape->begin(), block_shape->end()); | auto block_shapes = std::vector<int64_t>(block_shape->begin(), block_shape->end()); | ||||
| if (block_shapes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) { | if (block_shapes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) { | ||||
| MS_LOG(ERROR) << "The value of block_shapes.size() is too big"; | MS_LOG(ERROR) << "The value of block_shapes.size() is too big"; | ||||
| free(space_batch_param_nd); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| space_batch_param_nd->m_ = block_shapes.size(); | |||||
| param->m_ = block_shapes.size(); | |||||
| auto param_paddings = param->paddings(); | |||||
| auto param_paddings = value->paddings(); | |||||
| if (param_paddings == nullptr) { | if (param_paddings == nullptr) { | ||||
| MS_LOG(ERROR) << "param_paddings is nullptr"; | MS_LOG(ERROR) << "param_paddings is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto fb_paddings = param_paddings->data(); | auto fb_paddings = param_paddings->data(); | ||||
| if (fb_paddings == nullptr) { | if (fb_paddings == nullptr) { | ||||
| MS_LOG(ERROR) << "fb_paddings is nullptr"; | MS_LOG(ERROR) << "fb_paddings is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (fb_paddings->size() == 0 || | if (fb_paddings->size() == 0 || | ||||
| @@ -62,14 +65,15 @@ OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) { | |||||
| static_cast<uint64_t>(fb_paddings->size() * (*(fb_paddings->begin()))->data()->size()) > | static_cast<uint64_t>(fb_paddings->size() * (*(fb_paddings->begin()))->data()->size()) > | ||||
| std::numeric_limits<size_t>::max() / sizeof(int64_t))) { | std::numeric_limits<size_t>::max() / sizeof(int64_t))) { | ||||
| MS_LOG(ERROR) << "The value of paddings.size() is zero or too big"; | MS_LOG(ERROR) << "The value of paddings.size() is zero or too big"; | ||||
| free(space_batch_param_nd); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::vector<int64_t> paddings; | std::vector<int64_t> paddings; | ||||
| for (auto iter = fb_paddings->begin(); iter != fb_paddings->end(); ++iter) { | |||||
| auto paddings_data = (*iter)->data(); | |||||
| for (auto fb_padding : *fb_paddings) { | |||||
| auto paddings_data = fb_padding->data(); | |||||
| if (paddings_data == nullptr) { | if (paddings_data == nullptr) { | ||||
| MS_LOG(ERROR) << "paddings_data is nullptr"; | MS_LOG(ERROR) << "paddings_data is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto paddings_vec = std::vector<int64_t>(paddings_data->begin(), paddings_data->end()); | auto paddings_vec = std::vector<int64_t>(paddings_data->begin(), paddings_data->end()); | ||||
| @@ -77,17 +81,16 @@ OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) { | |||||
| } | } | ||||
| for (size_t i = 0; i < block_shapes.size(); ++i) { | for (size_t i = 0; i < block_shapes.size(); ++i) { | ||||
| space_batch_param_nd->block_sizes_[i] = static_cast<int>(block_shapes[i]); | |||||
| param->block_sizes_[i] = static_cast<int>(block_shapes[i]); | |||||
| } | } | ||||
| space_batch_param_nd->m_ = block_shapes.size(); | |||||
| param->m_ = block_shapes.size(); | |||||
| for (size_t i = 0; i < paddings.size(); ++i) { | for (size_t i = 0; i < paddings.size(); ++i) { | ||||
| space_batch_param_nd->paddings_[i] = static_cast<int>(paddings[i]); | |||||
| param->paddings_[i] = static_cast<int>(paddings[i]); | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(space_batch_param_nd); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_SpaceToBatchND, PopulateSpaceToBatchNDParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_SpaceToBatchND, PopulateSpaceToBatchNDParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,43 +19,47 @@ using mindspore::schema::PrimitiveType_SpaceToBatch; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateSpaceToBatchParameter(const void *prim) { | OpParameter *PopulateSpaceToBatchParameter(const void *prim) { | ||||
| auto *space_batch_param = reinterpret_cast<SpaceToBatchParameter *>(malloc(sizeof(SpaceToBatchParameter))); | |||||
| if (space_batch_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SpaceToBatchParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(space_batch_param, 0, sizeof(SpaceToBatchParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| space_batch_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_SpaceToBatch(); | |||||
| if (param == nullptr) { | |||||
| auto value = primitive->value_as_SpaceToBatch(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "param is nullptr"; | MS_LOG(ERROR) << "param is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto block_size = param->block_size(); | |||||
| auto *param = reinterpret_cast<SpaceToBatchParameter *>(malloc(sizeof(SpaceToBatchParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SpaceToBatchParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(SpaceToBatchParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto block_size = value->block_size(); | |||||
| if (block_size == nullptr) { | if (block_size == nullptr) { | ||||
| MS_LOG(ERROR) << "block_size is nullptr"; | MS_LOG(ERROR) << "block_size is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto block_sizes = std::vector<int64_t>(block_size->begin(), block_size->end()); | auto block_sizes = std::vector<int64_t>(block_size->begin(), block_size->end()); | ||||
| if (block_sizes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) { | if (block_sizes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) { | ||||
| MS_LOG(ERROR) << "The value of block_sizes.size() is too big"; | MS_LOG(ERROR) << "The value of block_sizes.size() is too big"; | ||||
| free(space_batch_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| space_batch_param->m_ = block_sizes.size(); | |||||
| param->m_ = block_sizes.size(); | |||||
| auto param_paddings = param->paddings(); | |||||
| auto param_paddings = value->paddings(); | |||||
| if (param_paddings == nullptr) { | if (param_paddings == nullptr) { | ||||
| MS_LOG(ERROR) << "param_paddings is nullptr"; | MS_LOG(ERROR) << "param_paddings is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto fb_paddings = param_paddings->data(); | auto fb_paddings = param_paddings->data(); | ||||
| if (fb_paddings == nullptr) { | if (fb_paddings == nullptr) { | ||||
| MS_LOG(ERROR) << "fb_paddings is nullptr"; | MS_LOG(ERROR) << "fb_paddings is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (fb_paddings->size() == 0 || | if (fb_paddings->size() == 0 || | ||||
| @@ -63,14 +67,15 @@ OpParameter *PopulateSpaceToBatchParameter(const void *prim) { | |||||
| static_cast<uint64_t>(fb_paddings->size() * (*(fb_paddings->begin()))->data()->size()) > | static_cast<uint64_t>(fb_paddings->size() * (*(fb_paddings->begin()))->data()->size()) > | ||||
| std::numeric_limits<size_t>::max() / sizeof(int64_t))) { | std::numeric_limits<size_t>::max() / sizeof(int64_t))) { | ||||
| MS_LOG(ERROR) << "The value of paddings.size() is zero or too big"; | MS_LOG(ERROR) << "The value of paddings.size() is zero or too big"; | ||||
| free(space_batch_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::vector<int64_t> paddings; | std::vector<int64_t> paddings; | ||||
| for (auto iter = fb_paddings->begin(); iter != fb_paddings->end(); ++iter) { | |||||
| auto paddings_data = (*iter)->data(); | |||||
| for (auto fb_padding : *fb_paddings) { | |||||
| auto paddings_data = fb_padding->data(); | |||||
| if (paddings_data == nullptr) { | if (paddings_data == nullptr) { | ||||
| MS_LOG(ERROR) << "paddings_data is nullptr"; | MS_LOG(ERROR) << "paddings_data is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto paddings_vec = std::vector<int64_t>(paddings_data->begin(), paddings_data->end()); | auto paddings_vec = std::vector<int64_t>(paddings_data->begin(), paddings_data->end()); | ||||
| @@ -78,15 +83,15 @@ OpParameter *PopulateSpaceToBatchParameter(const void *prim) { | |||||
| } | } | ||||
| for (size_t i = 0; i < block_sizes.size(); ++i) { | for (size_t i = 0; i < block_sizes.size(); ++i) { | ||||
| space_batch_param->block_sizes_[i] = static_cast<int>(block_sizes[i]); | |||||
| param->block_sizes_[i] = static_cast<int>(block_sizes[i]); | |||||
| } | } | ||||
| for (size_t i = 0; i < paddings.size(); ++i) { | for (size_t i = 0; i < paddings.size(); ++i) { | ||||
| space_batch_param->paddings_[i] = static_cast<int>(paddings[i]); | |||||
| param->paddings_[i] = static_cast<int>(paddings[i]); | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(space_batch_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_SpaceToBatch, PopulateSpaceToBatchParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_SpaceToBatch, PopulateSpaceToBatchParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,12 +20,6 @@ using mindspore::schema::PrimitiveType_SpaceToDepth; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateSpaceToDepthParameter(const void *prim) { | OpParameter *PopulateSpaceToDepthParameter(const void *prim) { | ||||
| auto *space_depth_param = reinterpret_cast<SpaceToDepthParameter *>(malloc(sizeof(SpaceToDepthParameter))); | |||||
| if (space_depth_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SpaceToDepthParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(space_depth_param, 0, sizeof(SpaceToDepthParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_SpaceToDepth(); | auto value = primitive->value_as_SpaceToDepth(); | ||||
| @@ -33,15 +27,24 @@ OpParameter *PopulateSpaceToDepthParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| space_depth_param->op_parameter_.type_ = primitive->value_type(); | |||||
| space_depth_param->block_size_ = value->block_size(); | |||||
| auto *param = reinterpret_cast<SpaceToDepthParameter *>(malloc(sizeof(SpaceToDepthParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SpaceToDepthParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(SpaceToDepthParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->block_size_ = value->block_size(); | |||||
| if (value->format() != schema::Format::Format_NHWC) { | if (value->format() != schema::Format::Format_NHWC) { | ||||
| MS_LOG(ERROR) << "Currently only NHWC format is supported."; | MS_LOG(ERROR) << "Currently only NHWC format is supported."; | ||||
| free(space_depth_param); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(space_depth_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_SpaceToDepth, PopulateSpaceToDepthParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_SpaceToDepth, PopulateSpaceToDepthParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,18 +20,20 @@ using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) { | OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) { | ||||
| auto *softmax_cross_entropy_param_ = | |||||
| reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter))); | |||||
| if (softmax_cross_entropy_param_ == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed."; | MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(softmax_cross_entropy_param_, 0, sizeof(SoftmaxCrossEntropyParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| softmax_cross_entropy_param_->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(softmax_cross_entropy_param_); | |||||
| memset(param, 0, sizeof(SoftmaxCrossEntropyParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_SparseSoftmaxCrossEntropyWithLogits, PopulateSparseSoftmaxCrossEntropyWithLogitsParameter, | REG_POPULATE(PrimitiveType_SparseSoftmaxCrossEntropyWithLogits, PopulateSparseSoftmaxCrossEntropyWithLogitsParameter, | ||||
| SCHEMA_CUR); | SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,20 +19,20 @@ using mindspore::schema::PrimitiveType_SparseToDense; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateSparseToDenseParameter(const void *prim) { | OpParameter *PopulateSparseToDenseParameter(const void *prim) { | ||||
| auto *sparse_to_dense_param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||||
| if (sparse_to_dense_param == nullptr) { | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto *param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SparseToDenseParameter failed."; | MS_LOG(ERROR) << "malloc SparseToDenseParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(sparse_to_dense_param, 0, sizeof(SparseToDenseParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| sparse_to_dense_param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(sparse_to_dense_param); | |||||
| memset(param, 0, sizeof(SparseToDenseParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_SparseToDense, PopulateSparseToDenseParameter, SCHEMA_CUR); | REG_POPULATE(PrimitiveType_SparseToDense, PopulateSparseToDenseParameter, SCHEMA_CUR); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -17,71 +17,75 @@ | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/splice_parameter.h" | #include "nnacl/splice_parameter.h" | ||||
| using mindspore::schema::PrimitiveType_Splice; | using mindspore::schema::PrimitiveType_Splice; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateSpliceParameter(const void *prim) { | OpParameter *PopulateSpliceParameter(const void *prim) { | ||||
| auto *splice_parameter = reinterpret_cast<SpliceParameter *>(malloc(sizeof(SpliceParameter))); | |||||
| if (splice_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc Splice Parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(splice_parameter, 0, sizeof(SpliceParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto splice_primitive = primitive->value_as_Splice(); | |||||
| if (splice_primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "splice_primitive is nullptr"; | |||||
| auto value = primitive->value_as_Splice(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| splice_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| auto context = splice_primitive->context(); | |||||
| auto *param = reinterpret_cast<SpliceParameter *>(malloc(sizeof(SpliceParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc Splice Parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(SpliceParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto context = value->context(); | |||||
| if (context == nullptr) { | if (context == nullptr) { | ||||
| MS_LOG(ERROR) << "context is nullptr"; | MS_LOG(ERROR) << "context is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::vector<int> primitive_context(context->begin(), context->end()); | std::vector<int> primitive_context(context->begin(), context->end()); | ||||
| splice_parameter->context_dim_ = static_cast<int>(primitive_context.size()); | |||||
| param->context_dim_ = static_cast<int>(primitive_context.size()); | |||||
| // malloc && memset for context | // malloc && memset for context | ||||
| splice_parameter->context_ = reinterpret_cast<int *>(malloc(splice_parameter->context_dim_ * sizeof(int))); | |||||
| if (splice_parameter->context_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc splice_parameter context_ error"; | |||||
| free(splice_parameter); | |||||
| param->context_ = reinterpret_cast<int *>(malloc(param->context_dim_ * sizeof(int))); | |||||
| if (param->context_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc param context_ error"; | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // src_to_dst_row_offset | // src_to_dst_row_offset | ||||
| int src_to_dst_row_offset = INT32_MIN; | int src_to_dst_row_offset = INT32_MIN; | ||||
| memset(splice_parameter->context_, 0, splice_parameter->context_dim_ * sizeof(int)); | |||||
| for (int i = 0; i < splice_parameter->context_dim_; ++i) { | |||||
| splice_parameter->context_[i] = primitive_context.at(i); | |||||
| memset(param->context_, 0, param->context_dim_ * sizeof(int)); | |||||
| for (int i = 0; i < param->context_dim_; ++i) { | |||||
| param->context_[i] = primitive_context.at(i); | |||||
| src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i))); | src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i))); | ||||
| } | } | ||||
| auto forward_indexes = splice_primitive->forward_indexes(); | |||||
| auto forward_indexes = value->forward_indexes(); | |||||
| if (forward_indexes == nullptr) { | if (forward_indexes == nullptr) { | ||||
| MS_LOG(ERROR) << "forward_indexes is nullptr"; | MS_LOG(ERROR) << "forward_indexes is nullptr"; | ||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::vector<int> primitive_forward_indexes(forward_indexes->begin(), forward_indexes->end()); | std::vector<int> primitive_forward_indexes(forward_indexes->begin(), forward_indexes->end()); | ||||
| splice_parameter->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size()); | |||||
| param->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size()); | |||||
| // malloc && memset for forward_indexes | // malloc && memset for forward_indexes | ||||
| splice_parameter->forward_indexes_ = | |||||
| reinterpret_cast<int *>(malloc(splice_parameter->forward_indexes_dim_ * sizeof(int))); | |||||
| if (splice_parameter->forward_indexes_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc splice_parameter forward_indexes_ error"; | |||||
| free(splice_parameter->context_); | |||||
| free(splice_parameter); | |||||
| param->forward_indexes_ = reinterpret_cast<int *>(malloc(param->forward_indexes_dim_ * sizeof(int))); | |||||
| if (param->forward_indexes_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc param forward_indexes_ error"; | |||||
| free(param->context_); | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(splice_parameter->forward_indexes_, 0, splice_parameter->forward_indexes_dim_ * sizeof(int)); | |||||
| for (int i = 0; i < splice_parameter->context_dim_; ++i) { | |||||
| splice_parameter->context_[i] = primitive_context.at(i); | |||||
| memset(param->forward_indexes_, 0, param->forward_indexes_dim_ * sizeof(int)); | |||||
| for (int i = 0; i < param->context_dim_; ++i) { | |||||
| param->context_[i] = primitive_context.at(i); | |||||
| } | } | ||||
| splice_parameter->output_dim_ = splice_primitive->output_dim(); | |||||
| return reinterpret_cast<OpParameter *>(splice_parameter); | |||||
| param->output_dim_ = value->output_dim(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_Splice, PopulateSpliceParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Splice, PopulateSpliceParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,15 +19,7 @@ using mindspore::schema::PrimitiveType_Split; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateSplitParameter(const void *prim) { | OpParameter *PopulateSplitParameter(const void *prim) { | ||||
| auto *split_param = reinterpret_cast<SplitParameter *>(malloc(sizeof(SplitParameter))); | |||||
| if (split_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SplitParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(split_param, 0, sizeof(SplitParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_Split(); | auto value = primitive->value_as_Split(); | ||||
| @@ -35,36 +27,44 @@ OpParameter *PopulateSplitParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| split_param->op_parameter_.type_ = primitive->value_type(); | |||||
| split_param->num_split_ = value->output_num(); | |||||
| if (split_param->num_split_ > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) { | |||||
| MS_LOG(ERROR) << "The value of split_param->num_split_ is too big"; | |||||
| free(split_param); | |||||
| auto *param = reinterpret_cast<SplitParameter *>(malloc(sizeof(SplitParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SplitParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(SplitParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->num_split_ = value->output_num(); | |||||
| if (param->num_split_ > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) { | |||||
| MS_LOG(ERROR) << "The value of param->num_split_ is too big"; | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| /* free split_sizes_ in split op base */ | /* free split_sizes_ in split op base */ | ||||
| split_param->split_sizes_ = reinterpret_cast<int *>(malloc(split_param->num_split_ * sizeof(int))); | |||||
| if (split_param->split_sizes_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc split_param split_sizes_ error"; | |||||
| free(split_param); | |||||
| param->split_sizes_ = reinterpret_cast<int *>(malloc(param->num_split_ * sizeof(int))); | |||||
| if (param->split_sizes_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc param split_sizes_ error"; | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(split_param->split_sizes_, 0, split_param->num_split_ * sizeof(int)); | |||||
| memset(param->split_sizes_, 0, param->num_split_ * sizeof(int)); | |||||
| auto split_sizes_vector_ = value->size_splits(); | auto split_sizes_vector_ = value->size_splits(); | ||||
| if (split_sizes_vector_ != nullptr) { | if (split_sizes_vector_ != nullptr) { | ||||
| int i = 0; | int i = 0; | ||||
| for (auto iter : *split_sizes_vector_) { | for (auto iter : *split_sizes_vector_) { | ||||
| split_param->split_sizes_[i++] = iter; | |||||
| param->split_sizes_[i++] = iter; | |||||
| } | } | ||||
| split_param->split_count_ = split_param->num_split_; | |||||
| param->split_count_ = param->num_split_; | |||||
| } else { | } else { | ||||
| split_param->split_count_ = 0; | |||||
| param->split_count_ = 0; | |||||
| } | } | ||||
| split_param->split_dim_ = value->axis(); | |||||
| return reinterpret_cast<OpParameter *>(split_param); | |||||
| param->split_dim_ = value->axis(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Split, PopulateSplitParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Split, PopulateSplitParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,48 +20,57 @@ using mindspore::schema::PrimitiveType_SplitWithOverlap; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateSplitWithOverlapParameter(const void *prim) { | OpParameter *PopulateSplitWithOverlapParameter(const void *prim) { | ||||
| auto *split_with_over_lap_param = | |||||
| reinterpret_cast<SplitWithOverlapParameter *>(malloc(sizeof(SplitWithOverlapParameter))); | |||||
| if (split_with_over_lap_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PopulateSplitWithOverlapParameter failed."; | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| auto value = primitive->value_as_SplitWithOverlap(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(split_with_over_lap_param, 0, sizeof(SplitWithOverlapParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| auto value = primitive->value_as_SplitWithOverlap(); | |||||
| split_with_over_lap_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto *param = reinterpret_cast<SplitWithOverlapParameter *>(malloc(sizeof(SplitWithOverlapParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PopulateSplitWithOverlapParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(SplitWithOverlapParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto ratio = value->ratio(); | auto ratio = value->ratio(); | ||||
| if (ratio == nullptr) { | |||||
| MS_LOG(ERROR) << "ratio is nullptr"; | |||||
| free(param); | |||||
| return nullptr; | |||||
| } | |||||
| if (ratio->size() > SPLIT_MAX_SLICE_NUM) { | if (ratio->size() > SPLIT_MAX_SLICE_NUM) { | ||||
| MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM | MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM | ||||
| << " slices"; | << " slices"; | ||||
| delete split_with_over_lap_param; | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| split_with_over_lap_param->num_split_ = static_cast<int>(ratio->size()); | |||||
| split_with_over_lap_param->split_dim_ = value->split_dim(); | |||||
| param->num_split_ = static_cast<int>(ratio->size()); | |||||
| param->split_dim_ = value->split_dim(); | |||||
| auto extend_top = value->extend_top(); | auto extend_top = value->extend_top(); | ||||
| auto extend_bottom = value->extend_bottom(); | auto extend_bottom = value->extend_bottom(); | ||||
| if (extend_top->size() != ratio->size() || extend_bottom->size() != ratio->size()) { | |||||
| if (extend_top->size() != ratio->size() || (extend_bottom != nullptr && extend_bottom->size() != ratio->size())) { | |||||
| MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical"; | MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical"; | ||||
| delete split_with_over_lap_param; | |||||
| free(param); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| for (size_t i = 0; i < ratio->size(); ++i) { | for (size_t i = 0; i < ratio->size(); ++i) { | ||||
| split_with_over_lap_param->ratio_[i] = (*ratio)[i]; | |||||
| split_with_over_lap_param->extend_top_[i] = (*extend_top)[i]; | |||||
| split_with_over_lap_param->extend_bottom_[i] = (*extend_bottom)[i]; | |||||
| param->ratio_[i] = (*ratio)[i]; | |||||
| param->extend_top_[i] = (*extend_top)[i]; | |||||
| param->extend_bottom_[i] = (*extend_bottom)[i]; | |||||
| } | } | ||||
| split_with_over_lap_param->stride_ = value->stride(); | |||||
| split_with_over_lap_param->pad_top_ = value->pad_top(); | |||||
| param->stride_ = value->stride(); | |||||
| param->pad_top_ = value->pad_top(); | |||||
| return reinterpret_cast<OpParameter *>(split_with_over_lap_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| REG_POPULATE(PrimitiveType_SplitWithOverlap, PopulateSplitWithOverlapParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_SplitWithOverlap, PopulateSplitWithOverlapParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,36 +19,35 @@ using mindspore::schema::PrimitiveType_Squeeze; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateSqueezeParameter(const void *prim) { | OpParameter *PopulateSqueezeParameter(const void *prim) { | ||||
| SqueezeParameter *squeeze_param = reinterpret_cast<SqueezeParameter *>(malloc(sizeof(SqueezeParameter))); | |||||
| if (squeeze_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SqueezeParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(squeeze_param, 0, sizeof(SqueezeParameter)); | |||||
| auto *primitive = static_cast<const schema::Primitive *>(prim); | auto *primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| squeeze_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto value = primitive->value_as_Squeeze(); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto squeeze_prim = primitive->value_as_Squeeze(); | |||||
| if (squeeze_prim == nullptr) { | |||||
| MS_LOG(ERROR) << "squeeze_prim is nullptr"; | |||||
| auto *param = reinterpret_cast<SqueezeParameter *>(malloc(sizeof(SqueezeParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc SqueezeParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto axis = squeeze_prim->axis(); | |||||
| if (squeeze_prim->axis() != nullptr) { | |||||
| squeeze_param->axis_size_ = axis->size(); | |||||
| for (size_t i = 0; i < squeeze_param->axis_size_; i++) { | |||||
| squeeze_param->axis_[i] = *(axis->begin() + i); | |||||
| memset(param, 0, sizeof(SqueezeParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto axis = value->axis(); | |||||
| if (axis != nullptr) { | |||||
| param->axis_size_ = axis->size(); | |||||
| for (size_t i = 0; i < param->axis_size_; i++) { | |||||
| param->axis_[i] = *(axis->begin() + i); | |||||
| } | } | ||||
| } else { | } else { | ||||
| squeeze_param->axis_size_ = 0; | |||||
| param->axis_size_ = 0; | |||||
| } | } | ||||
| return reinterpret_cast<OpParameter *>(squeeze_param); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Squeeze, PopulateSqueezeParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Squeeze, PopulateSqueezeParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,14 +19,7 @@ using mindspore::schema::PrimitiveType_Stack; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | |||||
| OpParameter *PopulateStackParameter(const void *prim) { | OpParameter *PopulateStackParameter(const void *prim) { | ||||
| auto *stack_param = reinterpret_cast<StackParameter *>(malloc(sizeof(StackParameter))); | |||||
| if (stack_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc StackParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(stack_param, 0, sizeof(StackParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto value = primitive->value_as_Stack(); | auto value = primitive->value_as_Stack(); | ||||
| @@ -34,11 +27,19 @@ OpParameter *PopulateStackParameter(const void *prim) { | |||||
| MS_LOG(ERROR) << "value is nullptr"; | MS_LOG(ERROR) << "value is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| stack_param->op_parameter_.type_ = primitive->value_type(); | |||||
| stack_param->axis_ = static_cast<int>(value->axis()); | |||||
| return reinterpret_cast<OpParameter *>(stack_param); | |||||
| auto *param = reinterpret_cast<StackParameter *>(malloc(sizeof(StackParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc StackParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(StackParameter)); | |||||
| param->op_parameter_.type_ = primitive->value_type(); | |||||
| param->axis_ = static_cast<int>(value->axis()); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | } | ||||
| } // namespace | |||||
| REG_POPULATE(PrimitiveType_Stack, PopulateStackParameter, SCHEMA_CUR) | REG_POPULATE(PrimitiveType_Stack, PopulateStackParameter, SCHEMA_CUR) | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||