|
|
|
@@ -40,6 +40,7 @@ namespace { |
|
|
|
constexpr const float EPS = 1e-8; |
|
|
|
constexpr const float EPS_DEFAULT_FLOAT = 1e-8; |
|
|
|
constexpr const float POW_NUM = 0.5; |
|
|
|
constexpr uint32_t kQuadrupleNum = 4; |
|
|
|
} // namespace |
|
|
|
|
|
|
|
STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) { |
|
|
|
@@ -52,6 +53,11 @@ STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto input_index = node->inputIndex.at(0); |
|
|
|
if (graph->allTensors.at(input_index)->dims.empty()) { |
|
|
|
MS_LOG(WARNING) << "The shape of input tensor is uncertain."; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
auto status = GenNewScaleTensor(graph, node); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status; |
|
|
|
@@ -75,9 +81,13 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std: |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
// after fusion bn must NHWC |
|
|
|
scaleParam->axis = -1; |
|
|
|
bnNode->primitive->value.value = scaleParam.release(); |
|
|
|
auto input0 = bnNode->inputIndex.at(0); |
|
|
|
if (graph->allTensors.at(input0)->dims.size() == kQuadrupleNum) { |
|
|
|
scaleParam->axis = -1; |
|
|
|
} else { |
|
|
|
scaleParam->axis = 1; |
|
|
|
} |
|
|
|
bnNode->primitive->value.value = scaleParam.release(); |
|
|
|
bnNode->inputIndex.clear(); |
|
|
|
bnNode->inputIndex.push_back(input0); |
|
|
|
graph->allTensors.emplace_back(std::move(newScaleWeightTensor)); |
|
|
|
|