|
|
|
@@ -298,10 +298,6 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap |
|
|
|
MS_LOG(ERROR) << "config should be specified"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (old_graph->has_flag("HasTransformed")) { |
|
|
|
old_graph->set_flag("HasTransformed", false); |
|
|
|
return old_graph; |
|
|
|
} |
|
|
|
|
|
|
|
auto status = RunPrecedingPass(old_graph, *config); |
|
|
|
if (status != RET_OK) { |
|
|
|
@@ -356,43 +352,44 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap |
|
|
|
return new_graph; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS AnfTransform::GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphVector *subgraphs, |
|
|
|
std::vector<ValueNodePtr> *vnodes) { |
|
|
|
auto nodes = TopoSort(main_graph->get_return()); |
|
|
|
void AnfTransform::GetAllFuncGraph(const FuncGraphPtr &func_graph) { |
|
|
|
if (func_graphs_.find(func_graph) == func_graphs_.end()) { |
|
|
|
func_graphs_.insert(func_graph); |
|
|
|
} else { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
auto nodes = func_graph->nodes(); |
|
|
|
for (auto &node : nodes) { |
|
|
|
auto fg = GetValueNode<FuncGraphPtr>(node); |
|
|
|
if (fg) { |
|
|
|
vnodes->push_back(utils::cast<ValueNodePtr>(node)); |
|
|
|
subgraphs->push_back(fg); |
|
|
|
if (IsValueNode<FuncGraph>(node)) { |
|
|
|
auto new_fg = (node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>(); |
|
|
|
GetAllFuncGraph(new_fg); |
|
|
|
} |
|
|
|
if (utils::isa<CNodePtr>(node)) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
for (auto &input : cnode->inputs()) { |
|
|
|
if (input->isa<ValueNode>()) { |
|
|
|
if (IsValueNode<FuncGraph>(input)) { |
|
|
|
auto new_fg = (input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>(); |
|
|
|
GetAllFuncGraph(new_fg); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { |
|
|
|
// transform main_graph |
|
|
|
auto new_main_graph = TransformSingleFuncGraph(main_graph, config); |
|
|
|
if (new_main_graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "TransformSingleFuncGraph failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
GetAllFuncGraph(main_graph); |
|
|
|
|
|
|
|
// transform sub_graph |
|
|
|
FuncGraphVector subgraphs{}; |
|
|
|
std::vector<ValueNodePtr> vnodes{}; |
|
|
|
int ret = GetAllFuncGraph(main_graph, &subgraphs, &vnodes); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "GetAllFuncGraph failed " << ret; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < subgraphs.size(); i++) { |
|
|
|
auto new_graph = Transform(subgraphs.at(i), config); |
|
|
|
new_graph->set_flag("HasTransformed", true); |
|
|
|
vnodes.at(i)->set_value(new_graph); |
|
|
|
for (auto &fg : func_graphs_) { |
|
|
|
auto new_main_graph = TransformSingleFuncGraph(fg, config); |
|
|
|
if (new_main_graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "TransformSingleFuncGraph failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return new_main_graph; |
|
|
|
return main_graph; |
|
|
|
} |
|
|
|
} // namespace mindspore::lite |