|
|
|
@@ -90,23 +90,29 @@ STATUS BatchNormConvertScalePass::DoFusion(MetaGraphT *graph, const std::string |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
auto bnPath = matchedPath.at(bnOpName); |
|
|
|
status = GetTransParam(graph, bnPath); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "GetTransParam failed: " << status; |
|
|
|
return status; |
|
|
|
} |
|
|
|
|
|
|
|
status = GenNewScaleTensor(graph, bnPath); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status; |
|
|
|
delete[] transScale; |
|
|
|
delete[] transBias; |
|
|
|
transScale = nullptr; |
|
|
|
transBias = nullptr; |
|
|
|
return status; |
|
|
|
} |
|
|
|
|
|
|
|
status = ConvertBNToScale(graph, bnPath); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status; |
|
|
|
delete[] transScale; |
|
|
|
delete[] transBias; |
|
|
|
transScale = nullptr; |
|
|
|
transBias = nullptr; |
|
|
|
return status; |
|
|
|
} |
|
|
|
delete[] transScale; |
|
|
|
delete[] transBias; |
|
|
|
transScale = nullptr; |
|
|
|
transBias = nullptr; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath) { |
|
|
|
@@ -245,6 +251,10 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::sh |
|
|
|
// cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps) |
|
|
|
if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != 0) { |
|
|
|
MS_LOG(ERROR) << "memcpy_s transScale error"; |
|
|
|
delete[] transScale; |
|
|
|
delete[] transBias; |
|
|
|
transScale = nullptr; |
|
|
|
transBias = nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
// 1/sqrt(variance + eps) |
|
|
|
@@ -370,14 +380,5 @@ STATUS BatchNormConvertScalePass::GetBnEpsilon(MetaGraphT *graph) { |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
BatchNormConvertScalePass::~BatchNormConvertScalePass() { |
|
|
|
if (this->transScale != nullptr) { |
|
|
|
delete (this->transScale); |
|
|
|
} |
|
|
|
if (this->transBias != nullptr) { |
|
|
|
delete (this->transBias); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace lite |
|
|
|
} // namespace mindspore |