|
|
|
@@ -22,6 +22,7 @@ |
|
|
|
#include "tools/converter/converter_flags.h" |
|
|
|
#include "third_party/securec/include/securec.h" |
|
|
|
#include "src/common/log_adapter.h" |
|
|
|
#include "src/common/common.h" |
|
|
|
#include "tools/common/tensor_util.h" |
|
|
|
#include "include/errorcode.h" |
|
|
|
#include "schema/inner/model_generated.h" |
|
|
|
@@ -39,7 +40,6 @@ namespace { |
|
|
|
constexpr const float EPS = 1e-8; |
|
|
|
constexpr const float EPS_DEFAULT_FLOAT = 1e-8; |
|
|
|
constexpr const float POW_NUM = 0.5; |
|
|
|
constexpr const int32_t NCHW_DIM_C = 1; |
|
|
|
} // namespace |
|
|
|
|
|
|
|
STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) { |
|
|
|
@@ -74,7 +74,9 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std: |
|
|
|
MS_LOG(ERROR) << "new scaleParam failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
scaleParam->axis = NCHW_DIM_C; |
|
|
|
int32_t axis = |
|
|
|
(graph->allTensors.at(bnNode->inputIndex.at(1))->format == Format_NHWC) ? (int32_t)NHWC_C : (int32_t)NCHW_C; |
|
|
|
scaleParam->axis = axis; |
|
|
|
bnNode->primitive->value.value = scaleParam.release(); |
|
|
|
auto input0 = bnNode->inputIndex.at(0); |
|
|
|
bnNode->inputIndex.clear(); |
|
|
|
|