From 2167833b550d59a7f627eccb0f29cb298cffe147 Mon Sep 17 00:00:00 2001 From: chenjianping Date: Tue, 26 Jan 2021 17:04:23 +0800 Subject: [PATCH] auto gen fbs --- build.sh | 19 +++++++++++++++++++ mindspore/lite/schema/ops.fbs | 2 +- mindspore/lite/src/ops/ops_def.cc | 2 +- .../src/train/train_populate_parameter.cc | 18 +++++++++--------- 4 files changed, 30 insertions(+), 11 deletions(-) diff --git a/build.sh b/build.sh index 4349547c03..d89337ca08 100755 --- a/build.sh +++ b/build.sh @@ -501,6 +501,24 @@ write_commit_file() { echo ${COMMIT_STR} > "${BASEPATH}/mindspore/lite/build/.commit_id" } +gen_fbs() { + if [[ "${ENABLE_TOOLS}" == "on" ]]; then + if [[ -f ${BASEPATH}/mindspore/lite/build/tools/schema_gen/schema_gen ]]; then + cd ${BASEPATH}/mindspore/lite/build/tools/schema_gen + ./schema_gen + cd - + diff_ops=$(diff ${BASEPATH}/mindspore/lite/build/tools/schema_gen/ops.fbs ${BASEPATH}/mindspore/lite/schema/ops.fbs || true) + if [[ "X${diff_ops}" != "X" ]]; then + cp ${BASEPATH}/mindspore/lite/build/tools/schema_gen/ops.fbs ${BASEPATH}/mindspore/lite/schema/ + fi + diff_types=$(diff ${BASEPATH}/mindspore/lite/build/tools/schema_gen/primitive_type.fbs ${BASEPATH}/mindspore/lite/schema/primitive_type.fbs || true) + if [[ "X${diff_types}" != "X" ]]; then + cp ${BASEPATH}/mindspore/lite/build/tools/schema_gen/primitive_type.fbs ${BASEPATH}/mindspore/lite/schema/ + fi + fi + fi +} + build_lite() { get_version @@ -572,6 +590,7 @@ build_lite() echo "---------------- mindspore lite: build failed ----------------" exit 1 else + gen_fbs mv ${BASEPATH}/output/tmp/*.tar.gz* ${BASEPATH}/output/ rm -rf ${BASEPATH}/output/tmp/ echo "---------------- mindspore lite: build success ----------------" diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 9b4377b2e0..4895bf914c 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -285,7 +285,7 @@ table Dropout { } table DropoutGrad { - ratio: float; + keep_prob: float; } table Elu { diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc index c037e6f1ce..2155a7327e 100644 --- a/mindspore/lite/src/ops/ops_def.cc +++ b/mindspore/lite/src/ops/ops_def.cc @@ -455,7 +455,7 @@ OP_ATTR_WITH_VALUE(keep_prob, float, 0.5) OP_SCHEMA_DEF_END(Dropout) OP_SCHEMA_DEF(DropoutGrad) -OP_ATTR(ratio, float) +OP_ATTR(keep_prob, float) OP_SCHEMA_DEF_END(DropoutGrad) OP_SCHEMA_DEF(Elu) diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc index 188b898368..760c7c4477 100644 --- a/mindspore/lite/src/train/train_populate_parameter.cc +++ b/mindspore/lite/src/train/train_populate_parameter.cc @@ -389,22 +389,22 @@ OpParameter *PopulateDropoutParameter(const void *prim) { } OpParameter *PopulateDropoutGradParameter(const void *prim) { - DropoutParameter *dropoutGrad_parameter = reinterpret_cast(malloc(sizeof(DropoutParameter))); - if (dropoutGrad_parameter == nullptr) { + DropoutParameter *dropoutgrad_parameter = reinterpret_cast(malloc(sizeof(DropoutParameter))); + if (dropoutgrad_parameter == nullptr) { MS_LOG(ERROR) << "malloc Dropout Grad Parameter failed."; return nullptr; } - memset(dropoutGrad_parameter, 0, sizeof(DropoutParameter)); + memset(dropoutgrad_parameter, 0, sizeof(DropoutParameter)); auto primitive = static_cast(prim); auto value = primitive->value_as_DropoutGrad(); - dropoutGrad_parameter->op_parameter_.type_ = primitive->value_type(); - dropoutGrad_parameter->ratio_ = value->ratio(); - if (dropoutGrad_parameter->ratio_ < 0.f || dropoutGrad_parameter->ratio_ > 1.f) { - MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 to 1, got " << dropoutGrad_parameter->ratio_; - free(dropoutGrad_parameter); + dropoutgrad_parameter->op_parameter_.type_ = primitive->value_type(); + dropoutgrad_parameter->ratio_ = value->keep_prob(); + if (dropoutgrad_parameter->ratio_ < 0.f || dropoutgrad_parameter->ratio_ > 1.f) { + MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 to 1, got " << dropoutgrad_parameter->ratio_; + free(dropoutgrad_parameter); return nullptr; } - return reinterpret_cast(dropoutGrad_parameter); + return reinterpret_cast(dropoutgrad_parameter); } OpParameter *PopulateArithmeticGradParameter(const void *prim) {