| @@ -52,9 +52,10 @@ std::string GetNodeRepr(AnfNodePtr node) { | |||||
| void ResolveFuncGraph_(const FuncGraphPtr &fg) { | void ResolveFuncGraph_(const FuncGraphPtr &fg) { | ||||
| auto manager = Manage(fg, false); | auto manager = Manage(fg, false); | ||||
| auto use_sig = parse::python_adapter::UseSignatureInResolve(); | |||||
| parse::python_adapter::set_use_signature_in_resolve(false); | parse::python_adapter::set_use_signature_in_resolve(false); | ||||
| parse::ResolveAll(manager); | parse::ResolveAll(manager); | ||||
| parse::python_adapter::set_use_signature_in_resolve(true); | |||||
| parse::python_adapter::set_use_signature_in_resolve(use_sig); | |||||
| } | } | ||||
| bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { | bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { | ||||
| @@ -145,6 +145,12 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb | |||||
| void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { | void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { | ||||
| std::string var = phi_nodes_[phi]; | std::string var = phi_nodes_[phi]; | ||||
| MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; | MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; | ||||
| auto removable = CollectRemovablePhi(phi); | |||||
| // If the phi node is not necessary, not need to add to jumps_ of the prev blocks. | |||||
| if (removable) { | |||||
| MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var; | |||||
| return; | |||||
| } | |||||
| for (auto &pred : prev_blocks_) { | for (auto &pred : prev_blocks_) { | ||||
| MS_EXCEPTION_IF_NULL(pred); | MS_EXCEPTION_IF_NULL(pred); | ||||
| MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); | MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); | ||||
| @@ -152,16 +158,6 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { | |||||
| CNodePtr jump = pred->jumps_[this]; | CNodePtr jump = pred->jumps_[this]; | ||||
| jump->add_input(arg_node); | jump->add_input(arg_node); | ||||
| } | } | ||||
| // If the phi node in the body part of a for/while loop is being removed, | |||||
| // then the closure convert phase will generate a cycle in graph if the | |||||
| // loop is kept after specialization. This should be investigate further. | |||||
| // Just now user has to set a flag on a function to indicate the for loop | |||||
| // will definitely can be unroll as the sequence in for statement is fixed | |||||
| // size in compile time. | |||||
| if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) || | |||||
| parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) { | |||||
| CollectRemovablePhi(phi); | |||||
| } | |||||
| } | } | ||||
| AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { | AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { | ||||
| @@ -207,13 +203,13 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame | |||||
| // 2. it's costly to iterate the graph to replace the phi for each phi. | // 2. it's costly to iterate the graph to replace the phi for each phi. | ||||
| // Args : | // Args : | ||||
| // phi : This parameter node is functioning as a phi node. | // phi : This parameter node is functioning as a phi node. | ||||
| void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { | |||||
| bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { | |||||
| MS_EXCEPTION_IF_NULL(phi); | MS_EXCEPTION_IF_NULL(phi); | ||||
| std::string var = phi_nodes_[phi]; | std::string var = phi_nodes_[phi]; | ||||
| MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); | |||||
| MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var; | |||||
| if (prev_blocks_.size() == 0) { | if (prev_blocks_.size() == 0) { | ||||
| MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString(); | |||||
| return; | |||||
| MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var; | |||||
| return false; | |||||
| } | } | ||||
| AnfNodePtr arg_node = SearchReplaceNode(var, phi); | AnfNodePtr arg_node = SearchReplaceNode(var, phi); | ||||
| if (arg_node != nullptr) { | if (arg_node != nullptr) { | ||||
| @@ -235,13 +231,16 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { | |||||
| const auto ¶m = phi_iter.second->cast<ParameterPtr>(); | const auto ¶m = phi_iter.second->cast<ParameterPtr>(); | ||||
| if (param == phi) { | if (param == phi) { | ||||
| MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() | MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() | ||||
| << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); | |||||
| << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString() | |||||
| << " in graph " << arg_node->func_graph()->ToString(); | |||||
| prev->removable_phis_[phi_iter.first] = arg_node; | prev->removable_phis_[phi_iter.first] = arg_node; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| return false; | |||||
| } | } | ||||
| // A block should be marked matured if its predecessor blocks have been processed | // A block should be marked matured if its predecessor blocks have been processed | ||||
| @@ -52,7 +52,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { | |||||
| AnfNodePtr ReadVariable(const std::string &var_name); | AnfNodePtr ReadVariable(const std::string &var_name); | ||||
| void AddPrevBlock(const FunctionBlockPtr &block); | void AddPrevBlock(const FunctionBlockPtr &block); | ||||
| void SetPhiArgument(const ParameterPtr &phi); | void SetPhiArgument(const ParameterPtr &phi); | ||||
| void CollectRemovablePhi(const ParameterPtr &phi); | |||||
| bool CollectRemovablePhi(const ParameterPtr &phi); | |||||
| // A block is matured if all its predecessors is generated | // A block is matured if all its predecessors is generated | ||||
| void Mature(); | void Mature(); | ||||
| CNodePtr ForceToBoolNode(const AnfNodePtr &cond); | CNodePtr ForceToBoolNode(const AnfNodePtr &cond); | ||||
| @@ -1436,6 +1436,15 @@ FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::obje | |||||
| return block; | return block; | ||||
| } | } | ||||
| AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) { | |||||
| const auto &inp = node->cast<ParameterPtr>(); | |||||
| const auto &iter = removable_phis.find(inp); | |||||
| if (iter == removable_phis.end()) { | |||||
| return node; | |||||
| } | |||||
| return FindPhis(removable_phis, iter->second); | |||||
| } | |||||
| void Parser::RemoveUnnecessaryPhis() { | void Parser::RemoveUnnecessaryPhis() { | ||||
| // merge all removable phis to one map; | // merge all removable phis to one map; | ||||
| std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis; | std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis; | ||||
| @@ -1443,28 +1452,39 @@ void Parser::RemoveUnnecessaryPhis() { | |||||
| MS_EXCEPTION_IF_NULL(block); | MS_EXCEPTION_IF_NULL(block); | ||||
| removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); | removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); | ||||
| } | } | ||||
| if (removable_phis.size() == 0) { | if (removable_phis.size() == 0) { | ||||
| return; | return; | ||||
| } | } | ||||
| for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) { | |||||
| if (node->isa<CNode>()) { | |||||
| const auto &cnode = node->cast<CNodePtr>(); | |||||
| auto &inputs = cnode->inputs(); | |||||
| for (std::size_t i = 0; i < inputs.size(); i++) { | |||||
| if (inputs[i]->isa<Parameter>()) { | |||||
| const auto &inp = inputs[i]->cast<ParameterPtr>(); | |||||
| const auto &iter = removable_phis.find(inp); | |||||
| if (iter == removable_phis.end()) { | |||||
| continue; | |||||
| } | |||||
| auto &argNode = iter->second; | |||||
| MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in " | |||||
| << cnode->DebugString() << " with " << argNode->DebugString(); | |||||
| cnode->set_input(i, argNode); | |||||
| } | |||||
| } | |||||
| auto fg_name = func_graph_->ToString(); | |||||
| auto mng = Manage(func_graph_, false); | |||||
| // replace the nodes | |||||
| for (auto iter : removable_phis) { | |||||
| auto new_node = FindPhis(removable_phis, iter.first); | |||||
| MS_LOG(DEBUG) << "phi " << iter.first->DebugString() << " to " << new_node->DebugString(); | |||||
| mng->Replace(iter.first, new_node); | |||||
| } | |||||
| // remove the parameter | |||||
| for (FunctionBlockPtr &block : func_block_list_) { | |||||
| MS_EXCEPTION_IF_NULL(block); | |||||
| auto &local_removable_phis = block->removable_phis(); | |||||
| if (local_removable_phis.size() == 0) { | |||||
| continue; | |||||
| } | } | ||||
| auto func_graph = block->func_graph(); | |||||
| auto ¶meters = func_graph->parameters(); | |||||
| std::vector<AnfNodePtr> new_parameters(parameters.size()); | |||||
| auto it = std::copy_if( | |||||
| parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) { | |||||
| return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end(); | |||||
| }); | |||||
| // shrink container to new size | |||||
| new_parameters.resize(std::distance(new_parameters.begin(), it)); | |||||
| func_graph->set_parameters(new_parameters); | |||||
| } | |||||
| for (auto fg : mng->func_graphs()) { | |||||
| fg->ClearAllManagerInfo(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -111,6 +111,27 @@ std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) { | |||||
| return sorted_nodes; | return sorted_nodes; | ||||
| } | } | ||||
| std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) { | |||||
| std::deque<FuncGraphPtr> todo; | |||||
| todo.push_back(root); | |||||
| std::vector<FuncGraphPtr> sorted; | |||||
| auto seen = NewSeenGeneration(); | |||||
| while (!todo.empty()) { | |||||
| FuncGraphPtr top = todo.front(); | |||||
| todo.pop_front(); | |||||
| sorted.push_back(top); | |||||
| auto used = top->func_graphs_used(); | |||||
| for (auto &item : used) { | |||||
| if (item.first->seen_ == seen) { | |||||
| continue; | |||||
| } | |||||
| todo.push_back(item.first); | |||||
| item.first->seen_ = seen; | |||||
| } | |||||
| } | |||||
| return sorted; | |||||
| } | |||||
| std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { | std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { | ||||
| std::vector<AnfNodePtr> vecs; | std::vector<AnfNodePtr> vecs; | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| @@ -70,6 +70,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = | |||||
| const IncludeFunc &include = AlwaysInclude); | const IncludeFunc &include = AlwaysInclude); | ||||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret); | std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret); | ||||
| std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root); | |||||
| class FuncGraphIndex { | class FuncGraphIndex { | ||||
| public: | public: | ||||
| explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, | explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, | ||||
| @@ -77,6 +77,17 @@ std::string CNode::DebugString(int recursive_level) const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| std::string Parameter::DebugString(int recursive_level) const { | |||||
| std::ostringstream buffer; | |||||
| if (recursive_level > 0) { | |||||
| if (func_graph() != nullptr) { | |||||
| buffer << func_graph()->ToString() << ":"; | |||||
| } | |||||
| } | |||||
| buffer << ToString(); | |||||
| return buffer.str(); | |||||
| } | |||||
| std::string ValueNode::ToString() const { | std::string ValueNode::ToString() const { | ||||
| MS_EXCEPTION_IF_NULL(value_); | MS_EXCEPTION_IF_NULL(value_); | ||||
| if (value_->isa<FuncGraph>()) { | if (value_->isa<FuncGraph>()) { | ||||
| @@ -249,7 +249,7 @@ class Parameter : public ANode { | |||||
| MS_DECLARE_PARENT(Parameter, ANode); | MS_DECLARE_PARENT(Parameter, ANode); | ||||
| void accept(AnfVisitor *v) override; | void accept(AnfVisitor *v) override; | ||||
| std::string DebugString(int recursive_level = 1) const override; | |||||
| std::string name() const { return name_; } | std::string name() const { return name_; } | ||||
| void set_name(const std::string &name) { name_ = name; } | void set_name(const std::string &name) { name_ = name; } | ||||
| std::string fullname_with_scope() override { return name(); }; | std::string fullname_with_scope() override { return name(); }; | ||||
| @@ -417,6 +417,15 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() { | |||||
| return mng->recursive_graphs(shared_from_base<FuncGraph>()); | return mng->recursive_graphs(shared_from_base<FuncGraph>()); | ||||
| } | } | ||||
| void FuncGraph::ClearAllManagerInfo() { | |||||
| ClearNodes(); | |||||
| ClearValueNodes(); | |||||
| ClearFuncGraphCNodesIndex(); | |||||
| ClearFreeVariables(); | |||||
| ClearFuncGraphsUsed(); | |||||
| ClearJFuncGraphs(); | |||||
| } | |||||
| AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { | AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { | ||||
| auto itr = this->parameter_default_value_.find(name); | auto itr = this->parameter_default_value_.find(name); | ||||
| if (itr == parameter_default_value_.end()) { | if (itr == parameter_default_value_.end()) { | ||||
| @@ -229,7 +229,8 @@ class FuncGraph : public FuncGraphBase { | |||||
| } | } | ||||
| this->debug_info_ = info; | this->debug_info_ = info; | ||||
| } | } | ||||
| // clear all info from manager | |||||
| void ClearAllManagerInfo(); | |||||
| // get all nodes belonging to this func graph | // get all nodes belonging to this func graph | ||||
| const AnfNodeSet &nodes(); | const AnfNodeSet &nodes(); | ||||
| void CopyNodes(const FuncGraphPtr &source); | void CopyNodes(const FuncGraphPtr &source); | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/profile.h" | #include "utils/profile.h" | ||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "utils/graph_utils.h" | |||||
| // namespace to support intermediate representation definition | // namespace to support intermediate representation definition | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -400,11 +401,16 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph | |||||
| } | } | ||||
| void Cloner::Lift() { | void Cloner::Lift() { | ||||
| for (auto &func_graph_params : repl_func_graph_params_) { | |||||
| auto &func_graph = func_graph_params.first; | |||||
| auto ¶ms = func_graph_params.second; | |||||
| for (auto &cnode : func_graph->func_graph_cnodes_index()) { | |||||
| LiftParameters(cnode.first->first->func_graph(), func_graph, params); | |||||
| // lift inner graph first | |||||
| auto sorted = BroadFirstSearchGraphUsed(*(manager_->roots().begin())); | |||||
| for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) { | |||||
| auto func_graph = *r_iter; | |||||
| auto iter = repl_func_graph_params_.find(func_graph); | |||||
| if (iter != repl_func_graph_params_.end()) { | |||||
| auto ¶ms = iter->second; | |||||
| for (auto &cnode : func_graph->func_graph_cnodes_index()) { | |||||
| LiftParameters(cnode.first->first->func_graph(), func_graph, params); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -520,12 +520,7 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { | |||||
| target->CopyFuncGraphsUsed(source); | target->CopyFuncGraphsUsed(source); | ||||
| target->CopyJFuncGraphs(source); | target->CopyJFuncGraphs(source); | ||||
| signals_->InvalidateComputer(); | signals_->InvalidateComputer(); | ||||
| source->ClearNodes(); | |||||
| source->ClearValueNodes(); | |||||
| source->ClearFuncGraphCNodesIndex(); | |||||
| source->ClearFreeVariables(); | |||||
| source->ClearFuncGraphsUsed(); | |||||
| source->ClearJFuncGraphs(); | |||||
| source->ClearAllManagerInfo(); | |||||
| } | } | ||||
| FuncGraphTransaction FuncGraphManager::Transact() { | FuncGraphTransaction FuncGraphManager::Transact() { | ||||
| @@ -72,6 +72,7 @@ class PyFuncGraphFetcher { | |||||
| mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); | mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); | ||||
| if (doResolve_) { | if (doResolve_) { | ||||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false); | std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false); | ||||
| mindspore::parse::python_adapter::set_use_signature_in_resolve(false); | |||||
| mindspore::parse::ResolveAll(manager); | mindspore::parse::ResolveAll(manager); | ||||
| } | } | ||||
| return func_graph; | return func_graph; | ||||
| @@ -131,3 +131,26 @@ def test_while_in_while(): | |||||
| output = while_in_while(c1, c2, c3) | output = while_in_while(c1, c2, c3) | ||||
| expect = Tensor([1274], mstype.int32) | expect = Tensor([1274], mstype.int32) | ||||
| assert output == expect | assert output == expect | ||||
| @ms_function | |||||
| def while_by_while_in_while(x, y, z): | |||||
| out = c4 | |||||
| while x < c2: | |||||
| y = c4 + c4 | |||||
| while y < c2: | |||||
| y = y + 1 | |||||
| out = out + y | |||||
| z = c4 + c4 | |||||
| while z < c2: | |||||
| z = z + 1 | |||||
| out = out + z | |||||
| x = x + 1 | |||||
| out = out + x | |||||
| return out | |||||
| def test_while_by_while_in_while(): | |||||
| output = while_by_while_in_while(c1, c2, c3) | |||||
| expect = Tensor([350], mstype.int32) | |||||
| assert output == expect | |||||