From: @zh_qh Reviewed-by: Signed-off-by:pull/13050/MERGE
| @@ -2,7 +2,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -35,6 +35,7 @@ | |||
| #include "ir/signature.h" | |||
| #include "debug/trace.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| // namespace to support composite operators definition | |||
| @@ -184,7 +185,9 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraph | |||
| return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); | |||
| }); | |||
| inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); | |||
| auto call_node = func_graph->NewCNodeInOrder(inputs2); | |||
| call_node->AddAttr(kAttrIgnoreSideEffect, MakeValue(true)); | |||
| inputs.push_back(call_node); | |||
| } | |||
| return func_graph->NewCNodeInOrder(inputs); | |||
| } | |||
| @@ -222,7 +225,9 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap | |||
| return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); | |||
| }); | |||
| inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); | |||
| auto call_node = func_graph->NewCNodeInOrder(inputs2); | |||
| call_node->AddAttr(kAttrIgnoreSideEffect, MakeValue(true)); | |||
| inputs.push_back(call_node); | |||
| } | |||
| return func_graph->NewCNodeInOrder(inputs); | |||
| } | |||
| @@ -253,7 +258,9 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGrap | |||
| j++; | |||
| } | |||
| inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); | |||
| auto call_node = func_graph->NewCNodeInOrder(inputs2); | |||
| call_node->AddAttr(kAttrIgnoreSideEffect, MakeValue(true)); | |||
| inputs.push_back(call_node); | |||
| } | |||
| return func_graph->NewCNodeInOrder(inputs); | |||
| } | |||
| @@ -0,0 +1,321 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/optimizer/auto_monad_eliminate.h" | |||
| #include <vector> | |||
| #include <unordered_set> | |||
| #include <unordered_map> | |||
| #include <algorithm> | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet, | |||
| std::vector<AnfNodePtr> *need_replace_loads) { | |||
| std::unordered_map<AnfNodePtr, size_t> load_groups_record; | |||
| std::vector<std::vector<size_t>> load_groups; | |||
| std::unordered_set<AnfNodePtr> unload_users_record; | |||
| for (size_t i = 0; i < toposet.size(); i++) { | |||
| auto &node = toposet[i]; | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| continue; | |||
| } | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) { | |||
| for (const auto &input : cnode->inputs()) { | |||
| if (input->isa<Parameter>() || | |||
| (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>())) { | |||
| unload_users_record.insert(input); | |||
| } | |||
| } | |||
| continue; | |||
| } | |||
| // Exclude free variable node. | |||
| if (cnode->func_graph() != fg) { | |||
| continue; | |||
| } | |||
| auto load_param = cnode->input(1); | |||
| // first time get same input1 of load. | |||
| if (load_groups_record.find(load_param) == load_groups_record.end()) { | |||
| load_groups_record[load_param] = load_groups.size(); | |||
| load_groups.push_back({i}); | |||
| if (unload_users_record.find(load_param) == unload_users_record.end()) { | |||
| need_replace_loads->emplace_back(cnode); | |||
| } | |||
| } else { | |||
| // not first time get same input1 of load | |||
| load_groups[load_groups_record[load_param]].push_back(i); | |||
| } | |||
| } | |||
| return load_groups; | |||
| } | |||
| std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) { | |||
| if (group.size() <= 1) { | |||
| return {}; | |||
| } | |||
| auto load_param = toposet[group.back()]->cast<CNodePtr>()->input(1); | |||
| size_t cur_load_index = 1; | |||
| size_t pre_load_index = 0; | |||
| std::vector<size_t> cur_group = {group[pre_load_index]}; | |||
| std::vector<std::vector<size_t>> split_groups; | |||
| while (cur_load_index < group.size()) { | |||
| const auto &cur_load = group[cur_load_index]; | |||
| const auto &prev_load = group[pre_load_index]; | |||
| const auto param_used_by_other = | |||
| std::any_of(toposet.begin() + prev_load, toposet.begin() + cur_load, [&load_param](const AnfNodePtr &node) { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| if (IsPrimitiveCNode(node, prim::kPrimLoad)) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| return std::any_of(inputs.begin(), inputs.end(), | |||
| [&load_param](const AnfNodePtr &input) { return load_param == input; }); | |||
| }); | |||
| if (param_used_by_other) { | |||
| split_groups.push_back(cur_group); | |||
| cur_group.clear(); | |||
| } | |||
| cur_group.push_back(cur_load); | |||
| pre_load_index++; | |||
| cur_load_index++; | |||
| } | |||
| // push back the last splited group. | |||
| split_groups.push_back(cur_group); | |||
| return split_groups; | |||
| } | |||
| // Pattern1====================================== | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // u3 = UpdateState(u2, b) | |||
| //==> | |||
| // delete the UpdateState | |||
| void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user, | |||
| const AnfNodePtr &load) { | |||
| const auto &load_cnode = load->cast<CNodePtr>(); | |||
| const auto &u = load_cnode->input(2); | |||
| manager->Replace(load_user, u); | |||
| } | |||
| // Pattern2====================================== | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // t = make_tuple(x, b) | |||
| // u3 = UpdateState(u2, t) | |||
| //==> | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // u3 = UpdateState(u2, x) | |||
| void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) { | |||
| // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load. | |||
| AnfNodePtr other_input = load; | |||
| for (size_t i = 1; i < make_tuple->size(); i++) { | |||
| if (make_tuple->input(i) != load) { | |||
| other_input = make_tuple->input(i); | |||
| break; | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(other_input); | |||
| manager->Replace(make_tuple, other_input); | |||
| } | |||
| // Pattern3====================================== | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // t = make_tuple(x, y, b, z) | |||
| // u3 = UpdateState(u2, t) | |||
| //==> | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // t = make_tuple(x, y, z) | |||
| // u3 = UpdateState(u2, t) | |||
| void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &make_tuple, | |||
| const AnfNodePtr &load) { | |||
| auto &make_tuple_inputs = make_tuple->inputs(); | |||
| std::vector<AnfNodePtr> new_make_tuple_inputs; | |||
| (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs), | |||
| [load](const AnfNodePtr &input) { return load != input; }); | |||
| const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs); | |||
| new_make_tuple->set_abstract(make_tuple->abstract()); | |||
| manager->Replace(make_tuple, new_make_tuple); | |||
| } | |||
| void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) { | |||
| auto load_users = manager->node_users()[load]; | |||
| for (const auto &load_user : load_users) { | |||
| // Pattern1 | |||
| if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) { | |||
| DeleteLoadUserUpdateState(manager, load_user.first, load); | |||
| continue; | |||
| } | |||
| if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) { | |||
| const auto &make_tuple = load_user.first->cast<CNodePtr>(); | |||
| auto &maketuple_users = manager->node_users()[make_tuple]; | |||
| auto maketuple_as_input_of_update = | |||
| maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState); | |||
| if (!maketuple_as_input_of_update) { | |||
| continue; | |||
| } | |||
| // Pattern2 | |||
| if (make_tuple->size() == 3) { | |||
| DeleteLoadUserMakeTuple(manager, make_tuple, load); | |||
| continue; | |||
| } | |||
| // Pattern3 | |||
| if (make_tuple->size() > 3) { | |||
| ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, | |||
| const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) { | |||
| if (group.size() <= 1) { | |||
| return false; | |||
| } | |||
| const auto &main = toposet[group[0]]; | |||
| for (size_t i = 1; i < group.size(); i++) { | |||
| ReplaceLoadUser(manager, fg, toposet[group[i]]); | |||
| manager->Replace(toposet[group[i]], main); | |||
| } | |||
| return true; | |||
| } | |||
| AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) { | |||
| auto ¶ms = fg->parameters(); | |||
| auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend(); | |||
| auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr ¶) { return HasAbstractUMonad(para); }); | |||
| if (iter != end) { | |||
| return *iter; | |||
| } | |||
| auto monad = NewValueNode(kUMonad); | |||
| monad->set_abstract(kUMonad->ToAbstract()); | |||
| return monad; | |||
| } | |||
| // Replace UpdateStates with U for first load. | |||
| // Covert: | |||
| // u1 = UpdateState(u, c) | |||
| // p1 = Load(para1, u1) // first load for para1 | |||
| // To: | |||
| // u1 = UpdateState(u, c) | |||
| // p1 = Load(para1, u') // u' is first monad in graph or new monad | |||
| bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) { | |||
| if (need_replace_loads.size() == 0) { | |||
| return false; | |||
| } | |||
| constexpr size_t second_input_index = 2; | |||
| auto monad = GetFirstMonad(fg); | |||
| for (const auto &load_node : need_replace_loads) { | |||
| if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) { | |||
| continue; | |||
| } | |||
| auto update_state = load_node->cast<CNodePtr>()->input(second_input_index); | |||
| if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState)) { | |||
| continue; | |||
| } | |||
| auto mgr = fg->manager(); | |||
| mgr->SetEdge(load_node, second_input_index, monad); | |||
| } | |||
| return true; | |||
| } | |||
| // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... => | |||
| // Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,... | |||
| bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const { | |||
| auto changed = false; | |||
| for (const FuncGraphPtr &fg : manager->func_graphs()) { | |||
| std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return()); | |||
| std::vector<AnfNodePtr> need_replace_loads; | |||
| std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads); | |||
| const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads); | |||
| if (update_state_replaced) { | |||
| changed = true; | |||
| } | |||
| // split group if there is no-load node between two load nodes. | |||
| std::vector<std::vector<size_t>> need_merge_loads; | |||
| for (auto &group : load_groups) { | |||
| auto groups = SplitGroup(toposet, group); | |||
| need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end()); | |||
| } | |||
| for (auto &group : need_merge_loads) { | |||
| const bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group); | |||
| if (!changed && replaced) { | |||
| changed = true; | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "changed: " << changed; | |||
| return changed; | |||
| } | |||
| // Eliminate auto monad node: | |||
| // From: | |||
| // u1 = UpdateState(...); | |||
| // xxx = User(u1); // Other users except below Depend. | |||
| // output = Depend(output, u1); | |||
| // return output; | |||
| // To: | |||
| // u1 = UpdateState(...); | |||
| // xxx = User(u1); | |||
| // return output; | |||
| bool AutoMonadEliminator::EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const { | |||
| auto changed = false; | |||
| for (const FuncGraphPtr &fg : manager->func_graphs()) { | |||
| auto output = fg->output(); | |||
| if (output == nullptr) { | |||
| continue; | |||
| } | |||
| if (!IsPrimitiveCNode(output, prim::kPrimDepend)) { | |||
| continue; | |||
| } | |||
| constexpr size_t attach_index = 2; | |||
| auto attach = output->cast<CNodePtr>()->input(attach_index); | |||
| if (!IsPrimitiveCNode(attach, prim::kPrimUpdateState)) { | |||
| continue; | |||
| } | |||
| auto &node_users = manager->node_users(); | |||
| auto iter = node_users.find(attach); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "No user of node: " << attach->DebugString(); | |||
| } | |||
| auto &users = iter->second; | |||
| if (users.size() <= 1) { | |||
| continue; | |||
| } | |||
| constexpr size_t input_index = 1; | |||
| auto input = output->cast<CNodePtr>()->input(input_index); | |||
| MS_LOG(DEBUG) << "Change " << output->DebugString() << " -> " << input->DebugString(); | |||
| fg->set_output(input); | |||
| changed = true; | |||
| } | |||
| MS_LOG(DEBUG) << "changed: " << changed; | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AUTO_MONAD_ELIMINATOR_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AUTO_MONAD_ELIMINATOR_H_ | |||
| #include "ir/anf.h" | |||
| #include "ir/manager.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class AutoMonadEliminator { | |||
| public: | |||
| AutoMonadEliminator() = default; | |||
| virtual ~AutoMonadEliminator() = default; | |||
| bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { | |||
| auto manager = optimizer->resource()->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(root); | |||
| // Never report change. | |||
| (void)ReplaceAutoMonadNode(manager); | |||
| (void)EliminateAutoMonadNode(manager); | |||
| return false; | |||
| } | |||
| private: | |||
| bool ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const; | |||
| bool EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AUTO_MONAD_ELIMINATOR_H_ | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -21,13 +21,10 @@ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <algorithm> | |||
| #include "abstract/abstract_function.h" | |||
| #include "utils/flags.h" | |||
| #include "utils/utils.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| @@ -120,254 +117,6 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { | |||
| return changed; | |||
| } | |||
| std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet, | |||
| std::vector<AnfNodePtr> *need_replace_loads) { | |||
| std::unordered_map<AnfNodePtr, size_t> load_groups_record; | |||
| std::vector<std::vector<size_t>> load_groups; | |||
| std::unordered_set<AnfNodePtr> unload_users_record; | |||
| for (size_t i = 0; i < toposet.size(); i++) { | |||
| auto &node = toposet[i]; | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| continue; | |||
| } | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) { | |||
| for (const auto &input : cnode->inputs()) { | |||
| if (input->isa<Parameter>() || | |||
| (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>())) { | |||
| unload_users_record.insert(input); | |||
| } | |||
| } | |||
| continue; | |||
| } | |||
| // Exclude free variable node. | |||
| if (cnode->func_graph() != fg) { | |||
| continue; | |||
| } | |||
| auto load_param = cnode->input(1); | |||
| // first time get same input1 of load. | |||
| if (load_groups_record.find(load_param) == load_groups_record.end()) { | |||
| load_groups_record[load_param] = load_groups.size(); | |||
| load_groups.push_back({i}); | |||
| if (unload_users_record.find(load_param) == unload_users_record.end()) { | |||
| need_replace_loads->emplace_back(cnode); | |||
| } | |||
| } else { | |||
| // not first time get same input1 of load | |||
| load_groups[load_groups_record[load_param]].push_back(i); | |||
| } | |||
| } | |||
| return load_groups; | |||
| } | |||
| std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) { | |||
| if (group.size() <= 1) { | |||
| return {}; | |||
| } | |||
| auto load_param = toposet[group.back()]->cast<CNodePtr>()->input(1); | |||
| size_t cur_load_index = 1; | |||
| size_t pre_load_index = 0; | |||
| std::vector<size_t> cur_group = {group[pre_load_index]}; | |||
| std::vector<std::vector<size_t>> split_groups; | |||
| while (cur_load_index < group.size()) { | |||
| const auto &cur_load = group[cur_load_index]; | |||
| const auto &prev_load = group[pre_load_index]; | |||
| const auto param_used_by_other = | |||
| std::any_of(toposet.begin() + prev_load, toposet.begin() + cur_load, [&load_param](const AnfNodePtr &node) { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| if (IsPrimitiveCNode(node, prim::kPrimLoad)) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| return std::any_of(inputs.begin(), inputs.end(), | |||
| [&load_param](const AnfNodePtr &input) { return load_param == input; }); | |||
| }); | |||
| if (param_used_by_other) { | |||
| split_groups.push_back(cur_group); | |||
| cur_group.clear(); | |||
| } | |||
| cur_group.push_back(cur_load); | |||
| pre_load_index++; | |||
| cur_load_index++; | |||
| } | |||
| // push back the last splited group. | |||
| split_groups.push_back(cur_group); | |||
| return split_groups; | |||
| } | |||
| // Pattern1====================================== | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // u3 = UpdateState(u2, b) | |||
| //==> | |||
| // delete the UpdateState | |||
| void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user, | |||
| const AnfNodePtr &load) { | |||
| const auto &load_cnode = load->cast<CNodePtr>(); | |||
| const auto &u = load_cnode->input(2); | |||
| manager->Replace(load_user, u); | |||
| } | |||
| // Pattern2====================================== | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // t = make_tuple(x, b) | |||
| // u3 = UpdateState(u2, t) | |||
| //==> | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // u3 = UpdateState(u2, x) | |||
| void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) { | |||
| // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load. | |||
| AnfNodePtr other_input = load; | |||
| for (size_t i = 1; i < make_tuple->size(); i++) { | |||
| if (make_tuple->input(i) != load) { | |||
| other_input = make_tuple->input(i); | |||
| break; | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(other_input); | |||
| manager->Replace(make_tuple, other_input); | |||
| } | |||
| // Pattern3====================================== | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // t = make_tuple(x, y, b, z) | |||
| // u3 = UpdateState(u2, t) | |||
| //==> | |||
| // a = Load(para1, u1) | |||
| // ... | |||
| // b = Load(para1, u2) | |||
| // t = make_tuple(x, y, z) | |||
| // u3 = UpdateState(u2, t) | |||
| void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &make_tuple, | |||
| const AnfNodePtr &load) { | |||
| auto &make_tuple_inputs = make_tuple->inputs(); | |||
| std::vector<AnfNodePtr> new_make_tuple_inputs; | |||
| (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs), | |||
| [load](const AnfNodePtr &input) { return load != input; }); | |||
| const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs); | |||
| new_make_tuple->set_abstract(make_tuple->abstract()); | |||
| manager->Replace(make_tuple, new_make_tuple); | |||
| } | |||
| void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) { | |||
| auto load_users = manager->node_users()[load]; | |||
| for (const auto &load_user : load_users) { | |||
| // Pattern1 | |||
| if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) { | |||
| DeleteLoadUserUpdateState(manager, load_user.first, load); | |||
| continue; | |||
| } | |||
| if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) { | |||
| const auto &make_tuple = load_user.first->cast<CNodePtr>(); | |||
| auto &maketuple_users = manager->node_users()[make_tuple]; | |||
| auto maketuple_as_input_of_update = | |||
| maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState); | |||
| if (!maketuple_as_input_of_update) { | |||
| continue; | |||
| } | |||
| // Pattern2 | |||
| if (make_tuple->size() == 3) { | |||
| DeleteLoadUserMakeTuple(manager, make_tuple, load); | |||
| continue; | |||
| } | |||
| // Pattern3 | |||
| if (make_tuple->size() > 3) { | |||
| ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, | |||
| const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) { | |||
| if (group.size() <= 1) { | |||
| return false; | |||
| } | |||
| const auto &main = toposet[group[0]]; | |||
| for (size_t i = 1; i < group.size(); i++) { | |||
| ReplaceLoadUser(manager, fg, toposet[group[i]]); | |||
| manager->Replace(toposet[group[i]], main); | |||
| } | |||
| return true; | |||
| } | |||
| AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) { | |||
| auto ¶ms = fg->parameters(); | |||
| auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend(); | |||
| auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr ¶) { return HasAbstractUMonad(para); }); | |||
| if (iter != end) { | |||
| return *iter; | |||
| } | |||
| auto monad = NewValueNode(kUMonad); | |||
| monad->set_abstract(kUMonad->ToAbstract()); | |||
| return monad; | |||
| } | |||
| // Replace UpdateStates with U for first load. | |||
| // Covert: | |||
| // u1 = UpdateState(u, c) | |||
| // p1 = Load(para1, u1) // first load for para1 | |||
| // To: | |||
| // u1 = UpdateState(u, c) | |||
| // p1 = Load(para1, u') // u' is first monad in graph or new monad | |||
| bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) { | |||
| if (need_replace_loads.size() == 0) { | |||
| return false; | |||
| } | |||
| constexpr size_t second_input_index = 2; | |||
| auto monad = GetFirstMonad(fg); | |||
| for (const auto &load_node : need_replace_loads) { | |||
| if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) { | |||
| continue; | |||
| } | |||
| auto update_state = load_node->cast<CNodePtr>()->input(second_input_index); | |||
| if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState)) { | |||
| continue; | |||
| } | |||
| auto mgr = fg->manager(); | |||
| mgr->SetEdge(load_node, second_input_index, monad); | |||
| } | |||
| return true; | |||
| } | |||
| // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... => | |||
| // Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,... | |||
| bool CSE::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const { | |||
| auto changed = false; | |||
| for (const FuncGraphPtr &fg : manager->func_graphs()) { | |||
| std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return()); | |||
| std::vector<AnfNodePtr> need_replace_loads; | |||
| std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads); | |||
| const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads); | |||
| if (update_state_replaced) { | |||
| changed = true; | |||
| } | |||
| // split group if there is no-load node between two load nodes. | |||
| std::vector<std::vector<size_t>> need_merge_loads; | |||
| for (auto &group : load_groups) { | |||
| auto groups = SplitGroup(toposet, group); | |||
| need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end()); | |||
| } | |||
| for (auto &group : need_merge_loads) { | |||
| const bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group); | |||
| if (!changed && replaced) { | |||
| changed = true; | |||
| } | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| // The op like print, summary, or the op do not has true output, and always as a depend node input. | |||
| static bool HasSideEffect(const AnfNodePtr &node) { | |||
| auto prim = GetCNodePrimitive(node); | |||
| @@ -507,9 +256,7 @@ bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::si | |||
| bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(root); | |||
| auto change1 = ReplaceAutoMonadNode(manager); | |||
| auto change2 = BuildOrderGroupAndDoReplace(manager); | |||
| return change1 || change2; | |||
| return BuildOrderGroupAndDoReplace(manager); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -42,7 +42,6 @@ class CSE { | |||
| private: | |||
| bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const; | |||
| bool ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const; | |||
| bool DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group, | |||
| std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const; | |||
| }; | |||
| @@ -33,6 +33,7 @@ | |||
| #include "frontend/optimizer/clean.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/graph_transform.h" | |||
| #include "frontend/optimizer/auto_monad_eliminate.h" | |||
| #include "frontend/parallel/step_parallel.h" | |||
| #include "frontend/parallel/step_auto_parallel.h" | |||
| #include "frontend/parallel/cache_embedding/cache_embedding.h" | |||
| @@ -183,6 +184,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| {"a_after_grad", a_after_grad}, | |||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | |||
| {"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)}, | |||
| {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())}, | |||
| {"cse", opt::OptPassConfig(opt::CSEPass(false))}, | |||
| {"a_3", a_3}}); | |||
| @@ -27,6 +27,7 @@ | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/operator/composite/multitype_funcgraph.h" | |||
| #include "utils/flags.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/ordered_map.h" | |||
| #include "base/core_ops.h" | |||
| #include "abstract/abstract_value.h" | |||
| @@ -1291,6 +1292,14 @@ class AutoMonadConverter { | |||
| } | |||
| AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) { | |||
| // Not attach UpdateState if set kAttrIgnoreSideEffect. | |||
| auto attr_ignore_side_effect = attach->cast<CNodePtr>()->GetAttr(kAttrIgnoreSideEffect); | |||
| auto ignore_side_effect = attr_ignore_side_effect != nullptr && attr_ignore_side_effect->isa<BoolImm>() && | |||
| GetValue<bool>(attr_ignore_side_effect); | |||
| if (ignore_side_effect) { | |||
| return state; | |||
| } | |||
| auto update_state = NewValueNode(prim::kPrimUpdateState); | |||
| auto update_state_cnode = func_graph_->NewCNode({update_state, state, attach}); | |||
| update_state_cnode->set_abstract(state->abstract()); | |||
| @@ -407,6 +407,7 @@ constexpr auto kAttrParallelTypeInfo = "parallel_type_info"; | |||
| constexpr auto kAttrCompositeType = "composite_type"; | |||
| constexpr auto kAttrStitch = "stitch"; | |||
| constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first"; | |||
| constexpr auto kAttrIgnoreSideEffect = "ignore_side_effect"; | |||
| constexpr auto kAttrSwitchLayer = "switch_layer"; | |||
| constexpr auto kAttrReturn = "return"; | |||