| @@ -52,9 +52,10 @@ std::string GetNodeRepr(AnfNodePtr node) { | |||
| void ResolveFuncGraph_(const FuncGraphPtr &fg) { | |||
| auto manager = Manage(fg, false); | |||
| auto use_sig = parse::python_adapter::UseSignatureInResolve(); | |||
| parse::python_adapter::set_use_signature_in_resolve(false); | |||
| parse::ResolveAll(manager); | |||
| parse::python_adapter::set_use_signature_in_resolve(true); | |||
| parse::python_adapter::set_use_signature_in_resolve(use_sig); | |||
| } | |||
| bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { | |||
| @@ -145,6 +145,12 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb | |||
| void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { | |||
| std::string var = phi_nodes_[phi]; | |||
| MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; | |||
| auto removable = CollectRemovablePhi(phi); | |||
| // If the phi node is not necessary, not need to add to jumps_ of the prev blocks. | |||
| if (removable) { | |||
| MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var; | |||
| return; | |||
| } | |||
| for (auto &pred : prev_blocks_) { | |||
| MS_EXCEPTION_IF_NULL(pred); | |||
| MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); | |||
| @@ -152,16 +158,6 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { | |||
| CNodePtr jump = pred->jumps_[this]; | |||
| jump->add_input(arg_node); | |||
| } | |||
| // If the phi node in the body part of a for/while loop is being removed, | |||
| // then the closure convert phase will generate a cycle in graph if the | |||
| // loop is kept after specialization. This should be investigate further. | |||
| // Just now user has to set a flag on a function to indicate the for loop | |||
| // will definitely can be unroll as the sequence in for statement is fixed | |||
| // size in compile time. | |||
| if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) || | |||
| parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) { | |||
| CollectRemovablePhi(phi); | |||
| } | |||
| } | |||
| AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { | |||
| @@ -207,13 +203,13 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame | |||
| // 2. it's costly to iterate the graph to replace the phi for each phi. | |||
| // Args : | |||
| // phi : This parameter node is functioning as a phi node. | |||
| void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { | |||
| bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { | |||
| MS_EXCEPTION_IF_NULL(phi); | |||
| std::string var = phi_nodes_[phi]; | |||
| MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); | |||
| MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var; | |||
| if (prev_blocks_.size() == 0) { | |||
| MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString(); | |||
| return; | |||
| MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var; | |||
| return false; | |||
| } | |||
| AnfNodePtr arg_node = SearchReplaceNode(var, phi); | |||
| if (arg_node != nullptr) { | |||
| @@ -235,13 +231,16 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { | |||
| const auto ¶m = phi_iter.second->cast<ParameterPtr>(); | |||
| if (param == phi) { | |||
| MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() | |||
| << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); | |||
| << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString() | |||
| << " in graph " << arg_node->func_graph()->ToString(); | |||
| prev->removable_phis_[phi_iter.first] = arg_node; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| // A block should be marked matured if its predecessor blocks have been processed | |||
| @@ -52,7 +52,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { | |||
| AnfNodePtr ReadVariable(const std::string &var_name); | |||
| void AddPrevBlock(const FunctionBlockPtr &block); | |||
| void SetPhiArgument(const ParameterPtr &phi); | |||
| void CollectRemovablePhi(const ParameterPtr &phi); | |||
| bool CollectRemovablePhi(const ParameterPtr &phi); | |||
| // A block is matured if all its predecessors is generated | |||
| void Mature(); | |||
| CNodePtr ForceToBoolNode(const AnfNodePtr &cond); | |||
| @@ -1436,6 +1436,15 @@ FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::obje | |||
| return block; | |||
| } | |||
| AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) { | |||
| const auto &inp = node->cast<ParameterPtr>(); | |||
| const auto &iter = removable_phis.find(inp); | |||
| if (iter == removable_phis.end()) { | |||
| return node; | |||
| } | |||
| return FindPhis(removable_phis, iter->second); | |||
| } | |||
| void Parser::RemoveUnnecessaryPhis() { | |||
| // merge all removable phis to one map; | |||
| std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis; | |||
| @@ -1443,28 +1452,39 @@ void Parser::RemoveUnnecessaryPhis() { | |||
| MS_EXCEPTION_IF_NULL(block); | |||
| removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); | |||
| } | |||
| if (removable_phis.size() == 0) { | |||
| return; | |||
| } | |||
| for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) { | |||
| if (node->isa<CNode>()) { | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| for (std::size_t i = 0; i < inputs.size(); i++) { | |||
| if (inputs[i]->isa<Parameter>()) { | |||
| const auto &inp = inputs[i]->cast<ParameterPtr>(); | |||
| const auto &iter = removable_phis.find(inp); | |||
| if (iter == removable_phis.end()) { | |||
| continue; | |||
| } | |||
| auto &argNode = iter->second; | |||
| MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in " | |||
| << cnode->DebugString() << " with " << argNode->DebugString(); | |||
| cnode->set_input(i, argNode); | |||
| } | |||
| } | |||
| auto fg_name = func_graph_->ToString(); | |||
| auto mng = Manage(func_graph_, false); | |||
| // replace the nodes | |||
| for (auto iter : removable_phis) { | |||
| auto new_node = FindPhis(removable_phis, iter.first); | |||
| MS_LOG(DEBUG) << "phi " << iter.first->DebugString() << " to " << new_node->DebugString(); | |||
| mng->Replace(iter.first, new_node); | |||
| } | |||
| // remove the parameter | |||
| for (FunctionBlockPtr &block : func_block_list_) { | |||
| MS_EXCEPTION_IF_NULL(block); | |||
| auto &local_removable_phis = block->removable_phis(); | |||
| if (local_removable_phis.size() == 0) { | |||
| continue; | |||
| } | |||
| auto func_graph = block->func_graph(); | |||
| auto ¶meters = func_graph->parameters(); | |||
| std::vector<AnfNodePtr> new_parameters(parameters.size()); | |||
| auto it = std::copy_if( | |||
| parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) { | |||
| return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end(); | |||
| }); | |||
| // shrink container to new size | |||
| new_parameters.resize(std::distance(new_parameters.begin(), it)); | |||
| func_graph->set_parameters(new_parameters); | |||
| } | |||
| for (auto fg : mng->func_graphs()) { | |||
| fg->ClearAllManagerInfo(); | |||
| } | |||
| } | |||
| @@ -111,6 +111,27 @@ std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) { | |||
| return sorted_nodes; | |||
| } | |||
| std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) { | |||
| std::deque<FuncGraphPtr> todo; | |||
| todo.push_back(root); | |||
| std::vector<FuncGraphPtr> sorted; | |||
| auto seen = NewSeenGeneration(); | |||
| while (!todo.empty()) { | |||
| FuncGraphPtr top = todo.front(); | |||
| todo.pop_front(); | |||
| sorted.push_back(top); | |||
| auto used = top->func_graphs_used(); | |||
| for (auto &item : used) { | |||
| if (item.first->seen_ == seen) { | |||
| continue; | |||
| } | |||
| todo.push_back(item.first); | |||
| item.first->seen_ = seen; | |||
| } | |||
| } | |||
| return sorted; | |||
| } | |||
| std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { | |||
| std::vector<AnfNodePtr> vecs; | |||
| if (node == nullptr) { | |||
| @@ -70,6 +70,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = | |||
| const IncludeFunc &include = AlwaysInclude); | |||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret); | |||
| std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root); | |||
| class FuncGraphIndex { | |||
| public: | |||
| explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, | |||
| @@ -77,6 +77,17 @@ std::string CNode::DebugString(int recursive_level) const { | |||
| return buffer.str(); | |||
| } | |||
| std::string Parameter::DebugString(int recursive_level) const { | |||
| std::ostringstream buffer; | |||
| if (recursive_level > 0) { | |||
| if (func_graph() != nullptr) { | |||
| buffer << func_graph()->ToString() << ":"; | |||
| } | |||
| } | |||
| buffer << ToString(); | |||
| return buffer.str(); | |||
| } | |||
| std::string ValueNode::ToString() const { | |||
| MS_EXCEPTION_IF_NULL(value_); | |||
| if (value_->isa<FuncGraph>()) { | |||
| @@ -249,7 +249,7 @@ class Parameter : public ANode { | |||
| MS_DECLARE_PARENT(Parameter, ANode); | |||
| void accept(AnfVisitor *v) override; | |||
| std::string DebugString(int recursive_level = 1) const override; | |||
| std::string name() const { return name_; } | |||
| void set_name(const std::string &name) { name_ = name; } | |||
| std::string fullname_with_scope() override { return name(); }; | |||
| @@ -417,6 +417,15 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() { | |||
| return mng->recursive_graphs(shared_from_base<FuncGraph>()); | |||
| } | |||
| void FuncGraph::ClearAllManagerInfo() { | |||
| ClearNodes(); | |||
| ClearValueNodes(); | |||
| ClearFuncGraphCNodesIndex(); | |||
| ClearFreeVariables(); | |||
| ClearFuncGraphsUsed(); | |||
| ClearJFuncGraphs(); | |||
| } | |||
| AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { | |||
| auto itr = this->parameter_default_value_.find(name); | |||
| if (itr == parameter_default_value_.end()) { | |||
| @@ -229,7 +229,8 @@ class FuncGraph : public FuncGraphBase { | |||
| } | |||
| this->debug_info_ = info; | |||
| } | |||
| // clear all info from manager | |||
| void ClearAllManagerInfo(); | |||
| // get all nodes belonging to this func graph | |||
| const AnfNodeSet &nodes(); | |||
| void CopyNodes(const FuncGraphPtr &source); | |||
| @@ -25,6 +25,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/profile.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "utils/graph_utils.h" | |||
| // namespace to support intermediate representation definition | |||
| namespace mindspore { | |||
| @@ -400,11 +401,16 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph | |||
| } | |||
| void Cloner::Lift() { | |||
| for (auto &func_graph_params : repl_func_graph_params_) { | |||
| auto &func_graph = func_graph_params.first; | |||
| auto ¶ms = func_graph_params.second; | |||
| for (auto &cnode : func_graph->func_graph_cnodes_index()) { | |||
| LiftParameters(cnode.first->first->func_graph(), func_graph, params); | |||
| // lift inner graph first | |||
| auto sorted = BroadFirstSearchGraphUsed(*(manager_->roots().begin())); | |||
| for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) { | |||
| auto func_graph = *r_iter; | |||
| auto iter = repl_func_graph_params_.find(func_graph); | |||
| if (iter != repl_func_graph_params_.end()) { | |||
| auto ¶ms = iter->second; | |||
| for (auto &cnode : func_graph->func_graph_cnodes_index()) { | |||
| LiftParameters(cnode.first->first->func_graph(), func_graph, params); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -520,12 +520,7 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { | |||
| target->CopyFuncGraphsUsed(source); | |||
| target->CopyJFuncGraphs(source); | |||
| signals_->InvalidateComputer(); | |||
| source->ClearNodes(); | |||
| source->ClearValueNodes(); | |||
| source->ClearFuncGraphCNodesIndex(); | |||
| source->ClearFreeVariables(); | |||
| source->ClearFuncGraphsUsed(); | |||
| source->ClearJFuncGraphs(); | |||
| source->ClearAllManagerInfo(); | |||
| } | |||
| FuncGraphTransaction FuncGraphManager::Transact() { | |||
| @@ -72,6 +72,7 @@ class PyFuncGraphFetcher { | |||
| mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); | |||
| if (doResolve_) { | |||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false); | |||
| mindspore::parse::python_adapter::set_use_signature_in_resolve(false); | |||
| mindspore::parse::ResolveAll(manager); | |||
| } | |||
| return func_graph; | |||
| @@ -131,3 +131,26 @@ def test_while_in_while(): | |||
| output = while_in_while(c1, c2, c3) | |||
| expect = Tensor([1274], mstype.int32) | |||
| assert output == expect | |||
| @ms_function | |||
| def while_by_while_in_while(x, y, z): | |||
| out = c4 | |||
| while x < c2: | |||
| y = c4 + c4 | |||
| while y < c2: | |||
| y = y + 1 | |||
| out = out + y | |||
| z = c4 + c4 | |||
| while z < c2: | |||
| z = z + 1 | |||
| out = out + z | |||
| x = x + 1 | |||
| out = out + x | |||
| return out | |||
| def test_while_by_while_in_while(): | |||
| output = while_by_while_in_while(c1, c2, c3) | |||
| expect = Tensor([350], mstype.int32) | |||
| assert output == expect | |||