This reverts commit d0d0cc2c71.
tags/v1.5.0-rc1
| @@ -122,7 +122,7 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf | |||
| return nullptr; | |||
| } | |||
| auto bn_infer = CreateBNInfer(graph, batchnorm, node); | |||
| TransferDependOrUpdateState(batchnorm, graph, bn_infer); | |||
| TransferDepend(batchnorm, graph, bn_infer); | |||
| return bn_infer; | |||
| } | |||
| } // namespace opt | |||
| @@ -125,7 +125,7 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c | |||
| return nullptr; | |||
| } | |||
| auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node); | |||
| TransferDependOrUpdateState(batchnorm_grad, graph, bn_infer_grad); | |||
| TransferDepend(batchnorm_grad, graph, bn_infer_grad); | |||
| return bn_infer_grad; | |||
| } | |||
| } // namespace opt | |||
| @@ -916,34 +916,21 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { | |||
| return new_value_node; | |||
| } | |||
| void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) { | |||
| void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) { | |||
| MS_EXCEPTION_IF_NULL(old_node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| // Find BatchNorm's output which is a Depend or UpdateState. | |||
| auto node_users = manager->node_users()[old_node]; | |||
| for (const auto &node_index : node_users) { | |||
| for (const auto &node_index : manager->node_users()[old_node]) { | |||
| AnfNodePtr output = node_index.first; | |||
| size_t index = IntToSize(node_index.second); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) || | |||
| AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) { | |||
| auto output_cnode = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(output_cnode); | |||
| auto inputs = output_cnode->inputs(); | |||
| std::vector<AnfNodePtr> new_inputs{output_cnode->input(0)}; | |||
| for (size_t i = 1; i < inputs.size(); i++) { | |||
| auto input = inputs[i]; | |||
| if (input == old_node) { | |||
| new_inputs.emplace_back(new_node); | |||
| } else { | |||
| new_inputs.emplace_back(input); | |||
| } | |||
| } | |||
| auto new_output = graph->NewCNode(new_inputs); | |||
| new_output->set_abstract(output->abstract()); | |||
| new_output->set_scope(output->scope()); | |||
| manager->Replace(output, new_output); | |||
| auto depend = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(depend); | |||
| depend->set_input(index, new_node); | |||
| } | |||
| } | |||
| } | |||
| @@ -216,8 +216,8 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor | |||
| // Create a new value node of func graph,not kernel graph | |||
| ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); | |||
| // Transfer depend or updatestate to the new node | |||
| void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); | |||
| // Transfer depend to the new node | |||
| void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); | |||
| AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); | |||
| @@ -76,17 +76,15 @@ class OrderEnforcer { | |||
| } | |||
| } | |||
| std::unordered_set<AnfNodePtr> CheckMakeTupleHaveLoad(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::unordered_set<AnfNodePtr> loads; | |||
| bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) { | |||
| auto inputs = cnode->inputs(); | |||
| for (size_t index = 1; index < inputs.size(); index++) { | |||
| auto input = cnode->input(index); | |||
| if (IsPrimitiveCNode(input, prim::kPrimLoad)) { | |||
| loads.insert(input); | |||
| return true; | |||
| } | |||
| } | |||
| return loads; | |||
| return false; | |||
| } | |||
| std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) { | |||
| @@ -157,31 +155,23 @@ class OrderEnforcer { | |||
| // u3 = UpdateState(u', maketuple2, addn) # need put addn or other-op into u3 inputs | |||
| // assign = Assign(para2, inputs, u3) | |||
| void HandleMakeTupleUsers(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto maketuple = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(maketuple); | |||
| std::unordered_set<AnfNodePtr> loads = CheckMakeTupleHaveLoad(maketuple); | |||
| if (!loads.empty()) { | |||
| if (CheckMakeTupleHaveLoad(maketuple)) { | |||
| auto update_state = FindLastUpdateState(maketuple); | |||
| if (update_state != nullptr) { | |||
| std::unordered_set<AnfNodePtr> maketuple_users = GetSpecialOperatorRealUsers(maketuple); | |||
| std::unordered_set<AnfNodePtr> no_push_all_users; | |||
| std::unordered_set<AnfNodePtr> no_push_maketuple_users; | |||
| // Push and Pull at the end of the execution order, | |||
| // In order to ensure push and pull operator cut into the same graph, do not put push operator into updatestate | |||
| for (auto maketuple_user : maketuple_users) { | |||
| if (!IsPrimitiveCNode(maketuple_user, prim::kPrimPush)) { | |||
| no_push_all_users.insert(maketuple_user); | |||
| } | |||
| } | |||
| for (auto load : loads) { | |||
| std::unordered_set<AnfNodePtr> load_users = GetSpecialOperatorRealUsers(load); | |||
| for (auto load_user : load_users) { | |||
| no_push_all_users.insert(load_user); | |||
| no_push_maketuple_users.insert(maketuple_user); | |||
| } | |||
| } | |||
| auto update_state_cnode = update_state->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(update_state_cnode); | |||
| AddInputEdges(update_state_cnode, no_push_all_users); | |||
| AddInputEdges(update_state_cnode, no_push_maketuple_users); | |||
| } | |||
| } | |||
| } | |||
| @@ -275,8 +265,6 @@ class OrderEnforcer { | |||
| // Add load users as input edges of the update_state node. | |||
| void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) { | |||
| auto sorted_load_users = SortLoadUsers(load_users); | |||
| auto inputs = update_state->inputs(); | |||
| size_t origin_size = inputs.size(); | |||
| for (auto &load_user : sorted_load_users) { | |||
| if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) { | |||
| continue; | |||
| @@ -284,16 +272,10 @@ class OrderEnforcer { | |||
| if (!IsDependOn(load_user, update_state)) { | |||
| processed_nodes_.insert(load_user); | |||
| if (!IsInUpdateState(load_user, update_state)) { | |||
| inputs.emplace_back(load_user); | |||
| manager_->AddEdge(update_state, load_user); | |||
| } | |||
| } | |||
| } | |||
| if (inputs.size() > origin_size) { | |||
| auto new_update_state = func_graph_->NewCNode(inputs); | |||
| new_update_state->set_abstract(update_state->abstract()); | |||
| new_update_state->set_scope(update_state->scope()); | |||
| manager_->Replace(update_state, new_update_state); | |||
| } | |||
| } | |||
| // Sort load users by their topo sort order. | |||
| @@ -157,9 +157,7 @@ def test_if_after_if_in_if(): | |||
| control_flow_if_after_if_in_if(IfAfterIfInIfNet, x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @pytest.mark.skip(reason="not supported side effect") | |||
| def test_if_after_if_in_if_01(): | |||
| x = Tensor(2, mstype.int32) | |||
| control_flow_if_after_if_in_if(IfAfterIfInIfNet1, x) | |||