diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index a256a61412..f9fcf527a5 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -2,7 +2,7 @@ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,7 @@ #include "ir/signature.h" #include "debug/trace.h" #include "utils/ms_context.h" +#include "utils/utils.h" namespace mindspore { // namespace to support composite operators definition @@ -184,7 +185,9 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraph return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); }); - inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); + auto call_node = func_graph->NewCNodeInOrder(inputs2); + call_node->AddAttr(kAttrIgnoreSideEffect, MakeValue(true)); + inputs.push_back(call_node); } return func_graph->NewCNodeInOrder(inputs); } @@ -222,7 +225,9 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGrap return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); }); - inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); + auto call_node = func_graph->NewCNodeInOrder(inputs2); + call_node->AddAttr(kAttrIgnoreSideEffect, MakeValue(true)); + inputs.push_back(call_node); } return func_graph->NewCNodeInOrder(inputs); } @@ -253,7 +258,9 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGrap j++; } - inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); + auto call_node = func_graph->NewCNodeInOrder(inputs2); + call_node->AddAttr(kAttrIgnoreSideEffect, MakeValue(true)); + inputs.push_back(call_node); } return func_graph->NewCNodeInOrder(inputs); } diff --git a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc new file mode 100644 index 0000000000..be31d880e9 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc @@ -0,0 +1,321 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "frontend/optimizer/auto_monad_eliminate.h"
+
+#include <vector>
+#include <algorithm>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "base/core_ops.h"
+
+namespace mindspore {
+namespace opt {
+std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
+                                                    std::vector<AnfNodePtr> *need_replace_loads) {
+  std::unordered_map<AnfNodePtr, size_t> load_groups_record;
+  std::vector<std::vector<size_t>> load_groups;
+  std::unordered_set<AnfNodePtr> unload_users_record;
+  for (size_t i = 0; i < toposet.size(); i++) {
+    auto &node = toposet[i];
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr) {
+      continue;
+    }
+    if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
+      for (const auto &input : cnode->inputs()) {
+        if (input->isa<Parameter>() ||
+            (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>())) {
+          unload_users_record.insert(input);
+        }
+      }
+      continue;
+    }
+    // Exclude free variable node.
+    if (cnode->func_graph() != fg) {
+      continue;
+    }
+    auto load_param = cnode->input(1);
+    // First time this Load input (input1) is seen.
+    if (load_groups_record.find(load_param) == load_groups_record.end()) {
+      load_groups_record[load_param] = load_groups.size();
+      load_groups.push_back({i});
+      if (unload_users_record.find(load_param) == unload_users_record.end()) {
+        need_replace_loads->emplace_back(cnode);
+      }
+    } else {
+      // This Load input has been seen before.
+      load_groups[load_groups_record[load_param]].push_back(i);
+    }
+  }
+  return load_groups;
+}
+
+std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) {
+  if (group.size() <= 1) {
+    return {};
+  }
+  auto load_param = toposet[group.back()]->cast<CNodePtr>()->input(1);
+  size_t cur_load_index = 1;
+  size_t pre_load_index = 0;
+  std::vector<size_t> cur_group = {group[pre_load_index]};
+  std::vector<std::vector<size_t>> split_groups;
+  while (cur_load_index < group.size()) {
+    const auto &cur_load = group[cur_load_index];
+    const auto &prev_load = group[pre_load_index];
+    const auto param_used_by_other =
+      std::any_of(toposet.begin() + prev_load, toposet.begin() + cur_load, [&load_param](const AnfNodePtr &node) {
+        if (!node->isa<CNode>()) {
+          return false;
+        }
+        if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
+          return false;
+        }
+        auto cnode = node->cast<CNodePtr>();
+        auto &inputs = cnode->inputs();
+        return std::any_of(inputs.begin(), inputs.end(),
+                           [&load_param](const AnfNodePtr &input) { return load_param == input; });
+      });
+    if (param_used_by_other) {
+      split_groups.push_back(cur_group);
+      cur_group.clear();
+    }
+    cur_group.push_back(cur_load);
+    pre_load_index++;
+    cur_load_index++;
+  }
+  // Push back the last split group.
+  split_groups.push_back(cur_group);
+  return split_groups;
+}
+
+// Pattern1======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// u3 = UpdateState(u2, b)
+//==>
+// delete the UpdateState
+void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
+                               const AnfNodePtr &load) {
+  const auto &load_cnode = load->cast<CNodePtr>();
+  const auto &u = load_cnode->input(2);
+  manager->Replace(load_user, u);
+}
+
+// Pattern2======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// t = make_tuple(x, b)
+// u3 = UpdateState(u2, t)
+//==>
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// u3 = UpdateState(u2, x)
+void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
+  // Initialize other_input with load in case all inputs of the make_tuple are the same load.
+  AnfNodePtr other_input = load;
+  for (size_t i = 1; i < make_tuple->size(); i++) {
+    if (make_tuple->input(i) != load) {
+      other_input = make_tuple->input(i);
+      break;
+    }
+  }
+  MS_EXCEPTION_IF_NULL(other_input);
+  manager->Replace(make_tuple, other_input);
+}
+
+// Pattern3======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// t = make_tuple(x, y, b, z)
+// u3 = UpdateState(u2, t)
+//==>
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// t = make_tuple(x, y, z)
+// u3 = UpdateState(u2, t)
+void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &make_tuple,
+                              const AnfNodePtr &load) {
+  auto &make_tuple_inputs = make_tuple->inputs();
+  std::vector<AnfNodePtr> new_make_tuple_inputs;
+  (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs),
+                     [load](const AnfNodePtr &input) { return load != input; });
+  const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs);
+  new_make_tuple->set_abstract(make_tuple->abstract());
+  manager->Replace(make_tuple, new_make_tuple);
+}
+
+void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
+  auto load_users = manager->node_users()[load];
+  for (const auto &load_user : load_users) {
+    // Pattern1
+    if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
+      DeleteLoadUserUpdateState(manager, load_user.first, load);
+      continue;
+    }
+    if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
+      const auto &make_tuple = load_user.first->cast<CNodePtr>();
+      auto &maketuple_users = manager->node_users()[make_tuple];
+      auto maketuple_as_input_of_update =
+        maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState);
+      if (!maketuple_as_input_of_update) {
+        continue;
+      }
+      // Pattern2
+      if (make_tuple->size() == 3) {
+        DeleteLoadUserMakeTuple(manager, make_tuple, load);
+        continue;
+      }
+      // Pattern3
+      if (make_tuple->size() > 3) {
+        ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
+      }
+    }
+  }
+}
+
+bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
+                          const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) {
+  if (group.size() <= 1) {
+    return false;
+  }
+  const auto &main = toposet[group[0]];
+  for (size_t i = 1; i < group.size(); i++) {
+    ReplaceLoadUser(manager, fg, toposet[group[i]]);
+    manager->Replace(toposet[group[i]], main);
+  }
+  return true;
+}
+
+AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
+  auto &params = fg->parameters();
+  auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend();
+  auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr &para) { return HasAbstractUMonad(para); });
+  if (iter != end) {
+    return *iter;
+  }
+  auto monad = NewValueNode(kUMonad);
+  monad->set_abstract(kUMonad->ToAbstract());
+  return monad;
+}
+
+// Replace UpdateStates with U for first load.
+// Convert:
+// u1 = UpdateState(u, c)
+// p1 = Load(para1, u1) // first load for para1
+// To:
+// u1 = UpdateState(u, c)
+// p1 = Load(para1, u') // u' is the first monad in the graph, or a new monad
+bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
+  if (need_replace_loads.size() == 0) {
+    return false;
+  }
+  constexpr size_t second_input_index = 2;
+  auto monad = GetFirstMonad(fg);
+  for (const auto &load_node : need_replace_loads) {
+    if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
+      continue;
+    }
+    auto update_state = load_node->cast<CNodePtr>()->input(second_input_index);
+    if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState)) {
+      continue;
+    }
+    auto mgr = fg->manager();
+    mgr->SetEdge(load_node, second_input_index, monad);
+  }
+  return true;
+}
+
+// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
+// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node1,...
+bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
+  auto changed = false;
+  for (const FuncGraphPtr &fg : manager->func_graphs()) {
+    std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
+    std::vector<AnfNodePtr> need_replace_loads;
+    std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads);
+    const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
+    if (update_state_replaced) {
+      changed = true;
+    }
+    // Split a group if there is a non-Load node between two Load nodes.
+    std::vector<std::vector<size_t>> need_merge_loads;
+    for (auto &group : load_groups) {
+      auto groups = SplitGroup(toposet, group);
+      need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end());
+    }
+    for (auto &group : need_merge_loads) {
+      const bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group);
+      if (!changed && replaced) {
+        changed = true;
+      }
+    }
+  }
+  MS_LOG(DEBUG) << "changed: " << changed;
+  return changed;
+}
+
+// Eliminate auto monad node:
+// From:
+//   u1 = UpdateState(...);
+//   xxx = User(u1); // Other users except the Depend below.
+//   output = Depend(output, u1);
+//   return output;
+// To:
+//   u1 = UpdateState(...);
+//   xxx = User(u1);
+//   return output;
+bool AutoMonadEliminator::EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const {
+  auto changed = false;
+  for (const FuncGraphPtr &fg : manager->func_graphs()) {
+    auto output = fg->output();
+    if (output == nullptr) {
+      continue;
+    }
+    if (!IsPrimitiveCNode(output, prim::kPrimDepend)) {
+      continue;
+    }
+    constexpr size_t attach_index = 2;
+    auto attach = output->cast<CNodePtr>()->input(attach_index);
+    if (!IsPrimitiveCNode(attach, prim::kPrimUpdateState)) {
+      continue;
+    }
+    auto &node_users = manager->node_users();
+    auto iter = node_users.find(attach);
+    if (iter == node_users.end()) {
+      MS_LOG(EXCEPTION) << "No user of node: " << attach->DebugString();
+    }
+    auto &users = iter->second;
+    if (users.size() <= 1) {
+      continue;
+    }
+    constexpr size_t input_index = 1;
+    auto input = output->cast<CNodePtr>()->input(input_index);
+    MS_LOG(DEBUG) << "Change " << output->DebugString() << " -> " << input->DebugString();
+    fg->set_output(input);
+    changed = true;
+  }
+  MS_LOG(DEBUG) << "changed: " << changed;
+  return changed;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.h b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.h
new file mode 100644
index 0000000000..1f5f136e45
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AUTO_MONAD_ELIMINATOR_H_
+#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AUTO_MONAD_ELIMINATOR_H_
+
+#include "ir/anf.h"
+#include "ir/manager.h"
+#include "frontend/optimizer/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class AutoMonadEliminator {
+ public:
+  AutoMonadEliminator() = default;
+  virtual ~AutoMonadEliminator() = default;
+
+  bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
+    auto manager = optimizer->resource()->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    manager->AddFuncGraph(root);
+
+    // Never report a change.
+    (void)ReplaceAutoMonadNode(manager);
+    (void)EliminateAutoMonadNode(manager);
+    return false;
+  }
+
+ private:
+  bool ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const;
+  bool EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AUTO_MONAD_ELIMINATOR_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc
index 0a2d5f124a..d9668307f1 100644
--- a/mindspore/ccsrc/frontend/optimizer/cse.cc
+++ b/mindspore/ccsrc/frontend/optimizer/cse.cc
@@ -1,7 +1,7 @@
 /**
  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,13 +21,10 @@ #include #include #include -#include -#include #include "abstract/abstract_function.h" #include "utils/flags.h" #include "utils/utils.h" -#include "base/core_ops.h" namespace mindspore { /* namespace to support opt */ @@ -120,254 +117,6 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { return changed; } -std::vector> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector &toposet, - std::vector *need_replace_loads) { - std::unordered_map load_groups_record; - std::vector> load_groups; - std::unordered_set unload_users_record; - for (size_t i = 0; i < toposet.size(); i++) { - auto &node = toposet[i]; - auto cnode = node->cast(); - if (cnode == nullptr) { - continue; - } - if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) { - for (const auto &input : cnode->inputs()) { - if (input->isa() || - (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast()->input(1)->isa())) { - unload_users_record.insert(input); - } - } - continue; - } - // Exclude free variable node. - if (cnode->func_graph() != fg) { - continue; - } - auto load_param = cnode->input(1); - // first time get same input1 of load. - if (load_groups_record.find(load_param) == load_groups_record.end()) { - load_groups_record[load_param] = load_groups.size(); - load_groups.push_back({i}); - if (unload_users_record.find(load_param) == unload_users_record.end()) { - need_replace_loads->emplace_back(cnode); - } - } else { - // not first time get same input1 of load - load_groups[load_groups_record[load_param]].push_back(i); - } - } - return load_groups; -} - -std::vector> SplitGroup(const std::vector &toposet, const std::vector &group) { - if (group.size() <= 1) { - return {}; - } - auto load_param = toposet[group.back()]->cast()->input(1); - size_t cur_load_index = 1; - size_t pre_load_index = 0; - std::vector cur_group = {group[pre_load_index]}; - std::vector> split_groups; - while (cur_load_index < group.size()) { - const auto &cur_load = group[cur_load_index]; - const auto &prev_load = group[pre_load_index]; - const auto param_used_by_other = - std::any_of(toposet.begin() + prev_load, toposet.begin() + cur_load, [&load_param](const AnfNodePtr &node) { - if (!node->isa()) { - return false; - } - if (IsPrimitiveCNode(node, prim::kPrimLoad)) { - return false; - } - auto cnode = node->cast(); - auto &inputs = cnode->inputs(); - return std::any_of(inputs.begin(), inputs.end(), - [&load_param](const AnfNodePtr &input) { return load_param == input; }); - }); - if (param_used_by_other) { - split_groups.push_back(cur_group); - cur_group.clear(); - } - cur_group.push_back(cur_load); - pre_load_index++; - cur_load_index++; - } - // push back the last splited group. - split_groups.push_back(cur_group); - return split_groups; -} - -// Pattern1====================================== -// a = Load(para1, u1) -// ... -// b = Load(para1, u2) -// u3 = UpdateState(u2, b) -//==> -// delete the UpdateState -void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user, - const AnfNodePtr &load) { - const auto &load_cnode = load->cast(); - const auto &u = load_cnode->input(2); - manager->Replace(load_user, u); -} - -// Pattern2====================================== -// a = Load(para1, u1) -// ... 
-// b = Load(para1, u2) -// t = make_tuple(x, b) -// u3 = UpdateState(u2, t) -//==> -// a = Load(para1, u1) -// ... -// b = Load(para1, u2) -// u3 = UpdateState(u2, x) -void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) { - // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load. - AnfNodePtr other_input = load; - for (size_t i = 1; i < make_tuple->size(); i++) { - if (make_tuple->input(i) != load) { - other_input = make_tuple->input(i); - break; - } - } - MS_EXCEPTION_IF_NULL(other_input); - manager->Replace(make_tuple, other_input); -} - -// Pattern3====================================== -// a = Load(para1, u1) -// ... -// b = Load(para1, u2) -// t = make_tuple(x, y, b, z) -// u3 = UpdateState(u2, t) -//==> -// a = Load(para1, u1) -// ... -// b = Load(para1, u2) -// t = make_tuple(x, y, z) -// u3 = UpdateState(u2, t) -void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &make_tuple, - const AnfNodePtr &load) { - auto &make_tuple_inputs = make_tuple->inputs(); - std::vector new_make_tuple_inputs; - (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs), - [load](const AnfNodePtr &input) { return load != input; }); - const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs); - new_make_tuple->set_abstract(make_tuple->abstract()); - manager->Replace(make_tuple, new_make_tuple); -} - -void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) { - auto load_users = manager->node_users()[load]; - for (const auto &load_user : load_users) { - // Pattern1 - if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) { - DeleteLoadUserUpdateState(manager, load_user.first, load); - continue; - } - if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) { - const auto &make_tuple = load_user.first->cast(); - auto &maketuple_users = manager->node_users()[make_tuple]; - auto maketuple_as_input_of_update = - maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState); - if (!maketuple_as_input_of_update) { - continue; - } - // Pattern2 - if (make_tuple->size() == 3) { - DeleteLoadUserMakeTuple(manager, make_tuple, load); - continue; - } - // Pattern3 - if (make_tuple->size() > 3) { - ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load); - } - } - } -} - -bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, - const std::vector &toposet, const std::vector &group) { - if (group.size() <= 1) { - return false; - } - const auto &main = toposet[group[0]]; - for (size_t i = 1; i < group.size(); i++) { - ReplaceLoadUser(manager, fg, toposet[group[i]]); - manager->Replace(toposet[group[i]], main); - } - return true; -} - -AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) { - auto ¶ms = fg->parameters(); - auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend(); - auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr ¶) { return HasAbstractUMonad(para); }); - if (iter != end) { - return *iter; - } - auto monad = NewValueNode(kUMonad); - monad->set_abstract(kUMonad->ToAbstract()); - return monad; -} - -// Replace UpdateStates with U for first load. 
-// Covert: -// u1 = UpdateState(u, c) -// p1 = Load(para1, u1) // first load for para1 -// To: -// u1 = UpdateState(u, c) -// p1 = Load(para1, u') // u' is first monad in graph or new monad -bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector &need_replace_loads) { - if (need_replace_loads.size() == 0) { - return false; - } - constexpr size_t second_input_index = 2; - auto monad = GetFirstMonad(fg); - for (const auto &load_node : need_replace_loads) { - if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) { - continue; - } - auto update_state = load_node->cast()->input(second_input_index); - if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState)) { - continue; - } - auto mgr = fg->manager(); - mgr->SetEdge(load_node, second_input_index, monad); - } - return true; -} - -// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... => -// Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,... -bool CSE::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const { - auto changed = false; - for (const FuncGraphPtr &fg : manager->func_graphs()) { - std::vector toposet = TopoSort(fg->get_return()); - std::vector need_replace_loads; - std::vector> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads); - const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads); - if (update_state_replaced) { - changed = true; - } - // split group if there is no-load node between two load nodes. - std::vector> need_merge_loads; - for (auto &group : load_groups) { - auto groups = SplitGroup(toposet, group); - need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end()); - } - for (auto &group : need_merge_loads) { - const bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group); - if (!changed && replaced) { - changed = true; - } - } - } - return changed; -} - // The op like print, summary, or the op do not has true output, and always as a depend node input. static bool HasSideEffect(const AnfNodePtr &node) { auto prim = GetCNodePrimitive(node); @@ -507,9 +256,7 @@ bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vectorAddFuncGraph(root); - auto change1 = ReplaceAutoMonadNode(manager); - auto change2 = BuildOrderGroupAndDoReplace(manager); - return change1 || change2; + return BuildOrderGroupAndDoReplace(manager); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/cse.h b/mindspore/ccsrc/frontend/optimizer/cse.h index 12341bbfb8..b578c6c1b2 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.h +++ b/mindspore/ccsrc/frontend/optimizer/cse.h @@ -1,7 +1,7 @@ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -42,7 +42,6 @@ class CSE {
 
  private:
   bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const;
-  bool ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const;
   bool DoReplace(const FuncGraphManagerPtr manager, const std::vector<AnfNodePtr> &order_group,
                  std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const;
 };
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index bf6be857fe..3af722277c 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -33,6 +33,7 @@
 #include "frontend/optimizer/clean.h"
 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/graph_transform.h"
+#include "frontend/optimizer/auto_monad_eliminate.h"
 #include "frontend/parallel/step_parallel.h"
 #include "frontend/parallel/step_auto_parallel.h"
 #include "frontend/parallel/cache_embedding/cache_embedding.h"
@@ -183,6 +184,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     {"a_after_grad", a_after_grad},
     {"renormalize", opt::OptPassConfig::Renormalize()},
     {"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)},
+    {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
     {"cse", opt::OptPassConfig(opt::CSEPass(false))},
     {"a_3", a_3}});
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc
index bdacaa3165..5cc124df9f 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc
@@ -27,6 +27,7 @@
 #include "frontend/operator/ops.h"
 #include "frontend/operator/composite/multitype_funcgraph.h"
 #include "utils/flags.h"
+#include "utils/utils.h"
 #include "utils/ordered_map.h"
 #include "base/core_ops.h"
 #include "abstract/abstract_value.h"
@@ -1291,6 +1292,14 @@ class AutoMonadConverter {
   }
 
   AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) {
+    // Do not attach an UpdateState if kAttrIgnoreSideEffect is set on the node.
+    auto attr_ignore_side_effect = attach->cast<CNodePtr>()->GetAttr(kAttrIgnoreSideEffect);
+    auto ignore_side_effect = attr_ignore_side_effect != nullptr && attr_ignore_side_effect->isa<BoolImm>() &&
+                              GetValue<bool>(attr_ignore_side_effect);
+    if (ignore_side_effect) {
+      return state;
+    }
+
     auto update_state = NewValueNode(prim::kPrimUpdateState);
     auto update_state_cnode = func_graph_->NewCNode({update_state, state, attach});
     update_state_cnode->set_abstract(state->abstract());
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index a771e53e4d..9b8626b7bc 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -407,6 +407,7 @@ constexpr auto kAttrParallelTypeInfo = "parallel_type_info";
 constexpr auto kAttrCompositeType = "composite_type";
 constexpr auto kAttrStitch = "stitch";
 constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
+constexpr auto kAttrIgnoreSideEffect = "ignore_side_effect";
 constexpr auto kAttrSwitchLayer = "switch_layer";
 constexpr auto kAttrReturn = "return";
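
Reviewer note (not part of the patch): the core of AutoMonadEliminator::ReplaceAutoMonadNode is the grouping logic in GenerateLoadGroups/SplitGroup above. Loads of the same parameter are collected in topological order, and a group is split wherever some other node touches that parameter between two Loads, so merging duplicated Loads never moves a read across a potential write. The standalone C++ sketch below models only that idea on plain structs; ToyNode, GroupLoads and SplitGroups are invented names for illustration, not MindSpore APIs.

// Illustrative sketch only (not MindSpore code): ToyNode, GroupLoads and SplitGroups
// are invented names that model GenerateLoadGroups/SplitGroup on plain data.
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyNode {
  bool is_load = false;  // true for "x = Load(param, u)"
  std::string param;     // the parameter this node reads (Load) or otherwise uses
};

// Collect the topological indices of Load nodes, grouped by the parameter they read.
std::unordered_map<std::string, std::vector<size_t>> GroupLoads(const std::vector<ToyNode> &topo) {
  std::unordered_map<std::string, std::vector<size_t>> groups;
  for (size_t i = 0; i < topo.size(); ++i) {
    if (topo[i].is_load) {
      groups[topo[i].param].push_back(i);
    }
  }
  return groups;
}

// Split a group wherever a non-Load node uses the same parameter between two Loads,
// so merging Loads never moves a read across a potential write.
std::vector<std::vector<size_t>> SplitGroups(const std::vector<ToyNode> &topo, const std::string &param,
                                             const std::vector<size_t> &group) {
  std::vector<std::vector<size_t>> result;
  std::vector<size_t> cur = {group.front()};
  for (size_t k = 1; k < group.size(); ++k) {
    bool used_between = false;
    for (size_t j = group[k - 1] + 1; j < group[k]; ++j) {
      if (!topo[j].is_load && topo[j].param == param) {
        used_between = true;
        break;
      }
    }
    if (used_between) {
      result.push_back(cur);
      cur.clear();
    }
    cur.push_back(group[k]);
  }
  result.push_back(cur);
  return result;
}

int main() {
  // Topological order: a = Load(p); Assign(p, x); b = Load(p); c = Load(p)
  std::vector<ToyNode> topo = {{true, "p"}, {false, "p"}, {true, "p"}, {true, "p"}};
  for (const auto &[param, group] : GroupLoads(topo)) {
    for (const auto &g : SplitGroups(topo, param, group)) {
      std::cout << param << ":";
      for (size_t i : g) {
        std::cout << " " << i;
      }
      std::cout << "\n";  // prints "p: 0" and "p: 2 3" -- only loads 2 and 3 can be merged
    }
  }
  return 0;
}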
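
Reviewer note (not part of the patch): Pattern2 and Pattern3 differ only in how the MakeTuple feeding an UpdateState is rewritten once the duplicated Load is merged away: with exactly one remaining input the tuple collapses to that input (DeleteLoadUserMakeTuple), otherwise the Load is simply filtered out of the tuple (ReplaceLoadUserMakeTuple). A minimal sketch of that filtering on plain strings, with the invented name RewriteTupleUser standing in for the real helpers:

// Illustrative sketch only (not MindSpore code): RewriteTupleUser is an invented name
// standing in for DeleteLoadUserMakeTuple / ReplaceLoadUserMakeTuple, on plain strings.
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// inputs[0] plays the role of the make_tuple primitive slot (CNode input(0)),
// the remaining entries are the attached nodes.
std::vector<std::string> RewriteTupleUser(const std::vector<std::string> &inputs, const std::string &load) {
  std::vector<std::string> kept;
  std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(kept),
               [&load](const std::string &in) { return in != load; });
  return kept;
}

int main() {
  // Pattern2: t = make_tuple(x, b) with b the merged Load -> only x remains,
  // so UpdateState(u2, t) can be rewritten to UpdateState(u2, x).
  std::vector<std::string> tuple2 = {"make_tuple", "x", "b"};
  auto kept2 = RewriteTupleUser(tuple2, "b");
  std::cout << kept2[1] << "\n";  // x

  // Pattern3: t = make_tuple(x, y, b, z) -> the Load is dropped, the tuple stays.
  std::vector<std::string> tuple4 = {"make_tuple", "x", "y", "b", "z"};
  for (const auto &in : RewriteTupleUser(tuple4, "b")) {
    std::cout << in << " ";  // make_tuple x y z
  }
  std::cout << "\n";
  return 0;
}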