|
|
|
@@ -325,12 +325,12 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
if (!opt::RunExternalPass(old_graph, opt::POSITION_BEGIN)) { |
|
|
|
if (!RunExternalPass(old_graph, PassPosition::POSITION_BEGIN)) { |
|
|
|
MS_LOG(ERROR) << "Run external pass failed, place is BEGIN"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
if (!opt::RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) { |
|
|
|
if (!RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) { |
|
|
|
MS_LOG(ERROR) << "Run transpose opt pass failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -355,12 +355,12 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (!opt::RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) { |
|
|
|
if (!RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) { |
|
|
|
MS_LOG(ERROR) << "Run transpose opt pass failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
if (!opt::RunExternalPass(old_graph, opt::POSITION_END)) { |
|
|
|
if (!RunExternalPass(old_graph, PassPosition::POSITION_END)) { |
|
|
|
MS_LOG(ERROR) << "Run external pass failed, place is END"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -388,11 +388,11 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con |
|
|
|
void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) { |
|
|
|
auto fmk = config->fmk; |
|
|
|
auto is_train = config->trainModel; |
|
|
|
opt::PassRegistry("DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)); |
|
|
|
opt::PassRegistry("DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>()); |
|
|
|
opt::PassRegistry("InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)); |
|
|
|
opt::PassRegistry("ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)); |
|
|
|
opt::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)); |
|
|
|
PassRegistry("DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)); |
|
|
|
PassRegistry("DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>()); |
|
|
|
PassRegistry("InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)); |
|
|
|
PassRegistry("ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)); |
|
|
|
PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)); |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { |
|
|
|
|