|
|
|
@@ -69,9 +69,9 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std: |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
MS_ASSERT(bnNode != nullptr); |
|
|
|
bnNode->primitive->value.type = schema::PrimitiveType_Scale; |
|
|
|
std::unique_ptr<ScaleT> scaleParam(new ScaleT()); |
|
|
|
std::unique_ptr<ScaleT> scaleParam(new (std::nothrow) ScaleT()); |
|
|
|
if (scaleParam == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new transposeParam failed"; |
|
|
|
MS_LOG(ERROR) << "new scaleParam failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
scaleParam->axis = NCHW_DIM_C; |
|
|
|
@@ -104,7 +104,7 @@ STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std |
|
|
|
newScaleWeightTensor->data.resize(weightShapeSize * sizeof(float)); |
|
|
|
auto ret = memcpy_s(newScaleWeightTensor->data.data(), weightShapeSize * sizeof(float), transScale, |
|
|
|
weightShapeSize * sizeof(float)); |
|
|
|
if (ret != RET_OK) { |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "memcpy error: " << ret; |
|
|
|
delete[] transScale; |
|
|
|
delete[] transBias; |
|
|
|
@@ -127,7 +127,7 @@ STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std |
|
|
|
newScaleBiasTensor->data.resize(weightShapeSize * sizeof(float)); |
|
|
|
ret = memcpy_s(newScaleBiasTensor->data.data(), weightShapeSize * sizeof(float), transBias, |
|
|
|
weightShapeSize * sizeof(float)); |
|
|
|
if (ret != RET_OK) { |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "memcpy error: " << ret; |
|
|
|
delete[] transScale; |
|
|
|
delete[] transBias; |
|
|
|
@@ -166,9 +166,17 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::un |
|
|
|
return status; |
|
|
|
} |
|
|
|
this->transScale = new (std::nothrow) float[bnChannel]; |
|
|
|
if (this->transScale == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new transScale failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
this->transBias = new (std::nothrow) float[bnChannel]; |
|
|
|
if (this->transBias == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new transBias failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
// cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps) |
|
|
|
if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != 0) { |
|
|
|
if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != EOK) { |
|
|
|
MS_LOG(ERROR) << "memcpy_s transScale error"; |
|
|
|
delete[] transScale; |
|
|
|
delete[] transBias; |
|
|
|
@@ -180,6 +188,10 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::un |
|
|
|
for (uint32_t i = 0; i < bnChannel; i++) { |
|
|
|
float tmp = transScale[i] + eps; |
|
|
|
tmp = pow(tmp, POW_NUM); |
|
|
|
if (tmp <= 0.0f) { |
|
|
|
MS_LOG(ERROR) << "divisor 'tmp' cannot be 0"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
transScale[i] = 1 / tmp; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -278,6 +290,7 @@ STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, BNWeight |
|
|
|
STATUS BatchNormConvertScalePass::GetBnEpsilon(const std::unique_ptr<CNodeT> &bnNode) { |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
MS_ASSERT(bnNode != nullptr); |
|
|
|
MS_ASSERT(bnNode->primitive != nullptr); |
|
|
|
if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) { |
|
|
|
eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon; |
|
|
|
} else if (bnNode->primitive->value.type == schema::PrimitiveType_BatchNorm) { |
|
|
|
|