| @@ -20,7 +20,6 @@ | |||
| #include <string> | |||
| #include "base/core_ops.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -53,6 +53,7 @@ | |||
| #include "backend/optimizer/pass/communication_op_fusion.h" | |||
| #include "backend/optimizer/gpu/concat_outputs_for_all_gather.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| #include "backend/optimizer/pass/optimize_updatestate.h" | |||
| #include "common/trans.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "debug/data_dump/e2e_dump.h" | |||
| @@ -184,6 +185,8 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra | |||
| pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>()); | |||
| pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>()); | |||
| pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>()); | |||
| // Remove node only used by UpdateState, in order to ensure the correct execution sequence in CudnnInplaceAggregate. | |||
| pm->AddPass(std::make_shared<opt::OptimizeUpdateState>()); | |||
| pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>()); | |||
| pm->AddPass(std::make_shared<opt::ReluV2Pass>()); | |||
| pm->AddPass(std::make_shared<opt::AddReluV2Fusion>()); | |||
| @@ -628,7 +628,6 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, | |||
| MS_EXCEPTION_IF_NULL(cnode_inputs); | |||
| auto origin_inputs = cnode->inputs(); | |||
| const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend); | |||
| const bool is_updatestate = IsPrimitiveCNode(cnode, prim::kPrimUpdateState); | |||
| // If there are multiple depends, only select the first depend as parameter. | |||
| for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { | |||
| auto anf = origin_inputs[input_idx]; | |||
| @@ -637,8 +636,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, | |||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | |||
| (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | |||
| continue; | |||
| } else if ((is_depend && input_idx > kRealInputIndexInDepend) || | |||
| (is_updatestate && input_idx > kUpdateStateRealInput)) { | |||
| } else if ((is_depend && input_idx > kRealInputIndexInDepend)) { | |||
| cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); | |||
| continue; | |||
| } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { | |||
| @@ -79,6 +79,31 @@ class OrderEnforcer { | |||
| return abs != nullptr && abs->isa<abstract::AbstractRef>(); | |||
| } | |||
| // Find Load or parameter users as the candidate nodes to enforce order of execution. | |||
| std::unordered_set<AnfNodePtr> GetSpecialOperatorRealUsers(const AnfNodePtr &node) { | |||
| auto &node_users = manager_->node_users(); | |||
| auto iter = node_users.find(node); | |||
| if (iter == node_users.end()) { | |||
| return {}; | |||
| } | |||
| std::unordered_set<AnfNodePtr> real_users; | |||
| auto &users = iter->second; | |||
| for (auto &user : users) { | |||
| auto &user_node = user.first; | |||
| real_users.insert(user_node); | |||
| } | |||
| return real_users; | |||
| } | |||
| bool IsOneOfPrimitive(const AnfNodePtr &node, const std::set<PrimitivePtr> &special_node_types) { | |||
| for (const auto &type : special_node_types) { | |||
| if (IsPrimitiveCNode(node, type)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| void EnforceOrderForOtherCNode(const CNodePtr &cnode) { | |||
| // Find refs from the cnode inputs. | |||
| auto &inputs = cnode->inputs(); | |||
| @@ -87,6 +112,7 @@ class OrderEnforcer { | |||
| if (!IsPrimitiveCNode(last_input, prim::kPrimUpdateState)) { | |||
| return; | |||
| } | |||
| const std::set<PrimitivePtr> special_operators = {prim::kPrimExpandDims}; | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto &input = inputs.at(i); | |||
| if (!IsRef(input)) { | |||
| @@ -96,7 +122,17 @@ class OrderEnforcer { | |||
| auto loads = FindLoadUsers(input); | |||
| for (auto load : loads) { | |||
| std::unordered_set<AnfNodePtr> load_users = FindUsers(load); | |||
| AddInputEdges(last_input->cast<CNodePtr>(), load_users); | |||
| std::unordered_set<AnfNodePtr> real_users; | |||
| for (auto load_user : load_users) { | |||
| // check the special operator, only one level of user is considered for now | |||
| if (IsOneOfPrimitive(load_user, special_operators)) { | |||
| std::unordered_set<AnfNodePtr> special_real_users = GetSpecialOperatorRealUsers(load_user); | |||
| real_users.insert(special_real_users.begin(), special_real_users.end()); | |||
| } else { | |||
| real_users.insert(load_user); | |||
| } | |||
| } | |||
| AddInputEdges(last_input->cast<CNodePtr>(), real_users); | |||
| } | |||
| } | |||
| } | |||
| @@ -126,7 +162,10 @@ class OrderEnforcer { | |||
| void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) { | |||
| auto sorted_load_users = SortLoadUsers(load_users); | |||
| for (auto &load_user : sorted_load_users) { | |||
| if (!IsDependOn(load_user, update_state) && !IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) { | |||
| if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) { | |||
| continue; | |||
| } | |||
| if (!IsDependOn(load_user, update_state)) { | |||
| processed_nodes_.insert(load_user); | |||
| if (!IsInUpdateState(load_user, update_state)) { | |||
| manager_->AddEdge(update_state, load_user); | |||
| @@ -225,7 +264,6 @@ class OrderEnforcer { | |||
| return loads; | |||
| } | |||
| private: | |||
| const FuncGraphPtr &func_graph_; | |||
| FuncGraphManagerPtr manager_; | |||
| std::unordered_map<AnfNodePtr, size_t> topo_sort_map_; | |||
| @@ -38,6 +38,7 @@ | |||
| #include "debug/rdr/running_data_recorder.h" | |||
| #include "utils/comm_manager.h" | |||
| #include "debug/debugger/debugger.h" | |||
| #include "backend/optimizer/pass/optimize_updatestate.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| @@ -251,6 +252,8 @@ void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) | |||
| pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>()); | |||
| pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>()); | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { | |||
| // Remove node only used by UpdateState, in order to ensure the correct execution sequence in CudnnInplaceAggregate. | |||
| pm->AddPass(std::make_shared<opt::OptimizeUpdateState>()); | |||
| pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>()); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::ReluV2Pass>()); | |||
| @@ -59,11 +59,11 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users, | |||
| continue; | |||
| } | |||
| auto &node_users = iter->second; | |||
| const bool has_outer_user = std::any_of( | |||
| std::begin(node_users), std::end(node_users), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool { | |||
| const bool is_outer_user = (seen.find(u.first) == seen.end()); | |||
| return is_outer_user && !(IsPrimitiveCNode(u.first, prim::kPrimUpdateState) && u.second > 2); | |||
| }); | |||
| const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users), | |||
| [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool { | |||
| const bool is_outer_user = (seen.find(u.first) == seen.end()); | |||
| return is_outer_user; | |||
| }); | |||
| if (has_outer_user) { | |||
| output.emplace_back(node); | |||
| } | |||
| @@ -127,16 +127,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||
| for (size_t i = value_start_index; i < inps.size(); ++i) { | |||
| args.emplace_back(NewValueNode(MakeValue(0))); | |||
| } | |||
| } else if (IsPrimitive(fn, prim::kPrimUpdateState)) { | |||
| args.emplace_back(RefSubGraphNode(fg, inps[1], &inputs, &eqv)); | |||
| args.emplace_back(RefSubGraphNode(fg, inps[kUpdateStateRealInput], &inputs, &eqv)); | |||
| const size_t additional_input_index = 3; | |||
| for (size_t i = additional_input_index; i < inps.size(); ++i) { | |||
| auto &input = inps[i]; | |||
| if (eqv.find(input) != eqv.end()) { | |||
| args.emplace_back(RefSubGraphNode(fg, input, &inputs, &eqv)); | |||
| } | |||
| } | |||
| } else { | |||
| (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), | |||
| [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); | |||
| @@ -0,0 +1,83 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import pytest | |||
| from mindspore.nn import Cell | |||
| from mindspore import context, Tensor, Parameter | |||
| import mindspore.ops.operations as P | |||
| import mindspore as ms | |||
| import numpy as np | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class AutoMonadAddnAdamNet(Cell):
    """Cell that reads var/m/v through AddN and then applies Adam to the same parameters.

    Args:
        var: initial value wrapped into the "var" Parameter.
        m: initial value wrapped into the "m" Parameter.
        v: initial value wrapped into the "v" Parameter.
    """

    def __init__(self, var, m, v):
        super(AutoMonadAddnAdamNet, self).__init__()
        self.apply_adam = P.Adam()
        self.var = Parameter(var, name="var")
        self.m = Parameter(m, name="m")
        self.v = Parameter(v, name="v")
        self.addn = P.AddN()
        self.mul = P.Mul()

    def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
        # AddN is issued before the Adam call that takes the same three parameters.
        summed = self.addn((self.var, self.m, self.v))
        self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
        return summed, self.var, self.m, self.v
| def _count_unequal_element(data_expected, data_me, rtol, atol): | |||
| assert data_expected.shape == data_me.shape | |||
| total_count = len(data_expected.flatten()) | |||
| error = np.abs(data_expected - data_me) | |||
| greater = np.greater(error, atol + np.abs(data_me) * rtol) | |||
| loss_count = np.count_nonzero(greater) | |||
| assert (loss_count / total_count) < rtol, \ | |||
| "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ | |||
| format(data_expected[greater], data_me[greater], error[greater]) | |||
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Assert that two numpy arrays agree within the given tolerances.

    When ``data_expected`` contains NaN, ``np.allclose`` must hold outright.
    Otherwise a failed ``np.allclose`` is handed to ``_count_unequal_element``,
    which applies its own element-ratio check.

    Args:
        data_expected: numpy array with the reference values.
        data_me: numpy array with the values under test.
        rtol: relative tolerance forwarded to ``np.allclose``.
        atol: absolute tolerance forwarded to ``np.allclose``.
        equal_nan: forwarded to ``np.allclose``.
    """
    expected_has_nan = np.any(np.isnan(data_expected))
    is_close = np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    if expected_has_nan:
        assert is_close
    elif not is_close:
        _count_unequal_element(data_expected, data_me, rtol, atol)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_auto_monad_addn_adam():
    """Compare AutoMonadAddnAdamNet outputs between graph mode and PyNative mode.

    The net is first run under the context set at module import (graph mode),
    then a fresh net is run with identical inputs in PyNative mode. All four
    outputs must agree within rtol=atol=0.001.
    """
    var = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    m = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    v = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    net = AutoMonadAddnAdamNet(var, m, v)
    beta1_power = Tensor(0.9, ms.float32)
    beta2_power = Tensor(0.999, ms.float32)
    lr = Tensor(0.1, ms.float32)
    beta1 = Tensor(0.9, ms.float32)
    beta2 = Tensor(0.999, ms.float32)
    epsilon = Tensor(1e-8, ms.float32)
    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    out, new_var, new_m, new_v = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)

    # NOTE(review): the second net is built from the same var/m/v tensors as the first;
    # this assumes Parameter does not share storage with them - TODO confirm.
    net = AutoMonadAddnAdamNet(var, m, v)
    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        out_pyn, new_var_pyn, new_m_pyn, new_v_pyn = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
    finally:
        # Restore the mode this module configures at import so the PyNative
        # switch does not leak into tests that run later in the same process.
        context.set_context(mode=context.GRAPH_MODE)

    allclose_nparray(out_pyn.asnumpy(), out.asnumpy(), 0.001, 0.001)
    allclose_nparray(new_var_pyn.asnumpy(), new_var.asnumpy(), 0.001, 0.001)
    allclose_nparray(new_m_pyn.asnumpy(), new_m.asnumpy(), 0.001, 0.001)
    allclose_nparray(new_v_pyn.asnumpy(), new_v.asnumpy(), 0.001, 0.001)