Browse Source

[MSLITE] batchnorm to scale bug

tags/v1.2.0-rc1
ling 5 years ago
parent
commit
f7cbeb1fe5
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc

+ 4
- 2
mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc View File

@@ -22,6 +22,7 @@
#include "tools/converter/converter_flags.h"
#include "third_party/securec/include/securec.h"
#include "src/common/log_adapter.h"
#include "src/common/common.h"
#include "tools/common/tensor_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
@@ -39,7 +40,6 @@ namespace {
constexpr const float EPS = 1e-8;
constexpr const float EPS_DEFAULT_FLOAT = 1e-8;
constexpr const float POW_NUM = 0.5;
constexpr const int32_t NCHW_DIM_C = 1;
} // namespace

STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) {
@@ -74,7 +74,9 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std:
MS_LOG(ERROR) << "new scaleParam failed";
return RET_ERROR;
}
scaleParam->axis = NCHW_DIM_C;
int32_t axis =
(graph->allTensors.at(bnNode->inputIndex.at(1))->format == Format_NHWC) ? (int32_t)NHWC_C : (int32_t)NCHW_C;
scaleParam->axis = axis;
bnNode->primitive->value.value = scaleParam.release();
auto input0 = bnNode->inputIndex.at(0);
bnNode->inputIndex.clear();


Loading…
Cancel
Save