/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "backend/optimizer/gpu/adam_fusion.h" #include #include #include #include "backend/session/anf_runtime_algorithm.h" #include "ir/primitive.h" #include "utils/utils.h" #include "backend/optimizer/common/helper.h" namespace mindspore { namespace opt { namespace { kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { std::vector inputs_format; std::vector outputs_format; std::vector inputs_type; std::vector outputs_type; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; size_t input_num = AnfAlgo::GetInputTensorNum(node); for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); inputs_format.push_back(kOpFormat_DEFAULT); } size_t output_num = AnfAlgo::GetOutputTensorNum(node); for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); outputs_format.push_back(kOpFormat_DEFAULT); } builder.SetInputsDeviceType(inputs_type); builder.SetInputsFormat(inputs_format); builder.SetOutputsDeviceType(outputs_type); builder.SetOutputsFormat(outputs_format); return builder.Build(); } } // namespace const BaseRef AdamFusion::DefinePattern() const { VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); VectorRef next_v = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); VectorRef update = VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_}); VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param}); next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m}); next_param = VectorRef({prim::kPrimDepend, next_param, assign_m}); VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state}); next_param = VectorRef({prim::kPrimDepend, next_param, assign_v}); return next_param; } const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(equiv); auto beta1_input = utils::cast((*equiv)[beta1_]); auto one_sub_beta1_input = utils::cast((*equiv)[one_sub_beta1_]); auto beta2_input = utils::cast((*equiv)[beta2_]); auto one_sub_beta2_input = utils::cast((*equiv)[one_sub_beta2_]); auto eps_input = utils::cast((*equiv)[eps_]); auto lr_input = utils::cast((*equiv)[lr_]); auto param_input = utils::cast((*equiv)[param_]); auto m_input = utils::cast((*equiv)[m_]); auto v_input = utils::cast((*equiv)[v_]); auto gradient_input = utils::cast((*equiv)[gradient_]); auto u_input = utils::cast((*equiv)[u_]); MS_EXCEPTION_IF_NULL(beta1_input); MS_EXCEPTION_IF_NULL(one_sub_beta1_input); MS_EXCEPTION_IF_NULL(beta2_input); MS_EXCEPTION_IF_NULL(one_sub_beta2_input); MS_EXCEPTION_IF_NULL(eps_input); MS_EXCEPTION_IF_NULL(lr_input); MS_EXCEPTION_IF_NULL(param_input); MS_EXCEPTION_IF_NULL(m_input); MS_EXCEPTION_IF_NULL(v_input); MS_EXCEPTION_IF_NULL(gradient_input); MS_EXCEPTION_IF_NULL(u_input); // Use depend(param, u) to maintain the execution order of FusedAdam and the previous operators. auto prim_depend = std::make_shared(prim::kPrimDepend->name()); MS_EXCEPTION_IF_NULL(prim_depend); std::vector param_inputs = {NewValueNode(prim_depend), param_input, u_input}; auto param = graph->NewCNode(param_inputs); MS_EXCEPTION_IF_NULL(param); param->set_abstract(param_input->abstract()); // Fused into a FusedAdam operator. auto prim = std::make_shared(kFusedAdamName); MS_EXCEPTION_IF_NULL(prim); std::vector inputs = {NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param, m_input, v_input, gradient_input}; auto adam = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(adam); auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get()); adam->set_scope(node->scope()); auto build_info = GenerateKernelBuildInfo(adam); AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get()); // Replace the parameters of the last UpdateState to maintain // the execution order of FusedAdam and the following operators. // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} auto n = node->cast()->input(2); auto fg = n->func_graph(); MS_EXCEPTION_IF_NULL(fg); auto mgr = fg->manager(); MS_EXCEPTION_IF_NULL(mgr); auto &node_users = mgr->node_users(); auto iter = node_users.find(n); if (iter == node_users.end()) { MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); } auto &users = iter->second; for (auto &user : users) { if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { (user.first)->cast()->set_input(1, u_input); (user.first)->cast()->set_input(2, adam); break; } } return adam; } } // namespace opt } // namespace mindspore