You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

adam_fusion.cc 7.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/gpu/adam_fusion.h"
  17. #include <memory>
  18. #include <vector>
  19. #include <string>
  20. #include "backend/session/anf_runtime_algorithm.h"
  21. #include "ir/primitive.h"
  22. #include "utils/utils.h"
  23. #include "backend/optimizer/common/helper.h"
  24. namespace mindspore {
  25. namespace opt {
  26. namespace {
  27. kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  28. std::vector<std::string> inputs_format;
  29. std::vector<std::string> outputs_format;
  30. std::vector<TypeId> inputs_type;
  31. std::vector<TypeId> outputs_type;
  32. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  33. MS_EXCEPTION_IF_NULL(node);
  34. size_t input_num = AnfAlgo::GetInputTensorNum(node);
  35. for (size_t input_index = 0; input_index < input_num; ++input_index) {
  36. inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
  37. inputs_format.push_back(kOpFormat_DEFAULT);
  38. }
  39. size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  40. for (size_t output_index = 0; output_index < output_num; ++output_index) {
  41. outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
  42. outputs_format.push_back(kOpFormat_DEFAULT);
  43. }
  44. builder.SetInputsDeviceType(inputs_type);
  45. builder.SetInputsFormat(inputs_format);
  46. builder.SetOutputsDeviceType(outputs_type);
  47. builder.SetOutputsFormat(outputs_format);
  48. return builder.Build();
  49. }
  50. AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u_input) {
  51. // Replace the parameters of the last UpdateState to maintain
  52. // the execution order of FusedAdam and the following operators.
  53. // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
  54. const size_t assign_index = 2;
  55. auto cnode = node->cast<CNodePtr>();
  56. MS_EXCEPTION_IF_NULL(cnode);
  57. const auto &n = cnode->input(assign_index);
  58. MS_EXCEPTION_IF_NULL(n);
  59. const auto &fg = n->func_graph();
  60. MS_EXCEPTION_IF_NULL(fg);
  61. auto mgr = fg->manager();
  62. MS_EXCEPTION_IF_NULL(mgr);
  63. auto &node_users = mgr->node_users();
  64. auto iter = node_users.find(n);
  65. if (iter == node_users.end()) {
  66. MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
  67. }
  68. auto &users = iter->second;
  69. for (auto &user : users) {
  70. if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
  71. const size_t monad_index = 1;
  72. const size_t adam_index = 2;
  73. auto cnode_ptr = (user.first)->cast<CNodePtr>();
  74. MS_EXCEPTION_IF_NULL(cnode_ptr);
  75. cnode_ptr->set_input(monad_index, u_input);
  76. cnode_ptr->set_input(adam_index, adam);
  77. break;
  78. }
  79. }
  80. return adam;
  81. }
  82. } // namespace
// Describes the ANF sub-graph emitted by the python-level Adam update:
//   next_m = beta1 * m + (1 - beta1) * grad
//   next_v = beta2 * v + (1 - beta2) * grad^2
//   next_param = param - lr * next_m / (eps + sqrt(next_v))
// interleaved with the Load/Assign/UpdateState/Depend nodes that encode the
// side-effect ordering. The returned root is the final Depend on assign_v.
const BaseRef AdamFusion::DefinePattern() const {
  // Read current param/m/v under monad u_ so the loads are ordered.
  VectorRef load_param = VectorRef({prim::kPrimLoad, param_, u_});
  VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_});
  VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
  // First-moment update: beta1*m + (1-beta1)*grad.
  VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}),
                                VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
  // Second-moment update: beta2*v + (1-beta2)*grad^2.
  VectorRef next_v =
    VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}),
               VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
  // Parameter delta: next_m / (eps + sqrt(next_v)), scaled by lr.
  VectorRef update =
    VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
  VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
  VectorRef next_param = VectorRef({prim::kPrimSub, load_param, update_with_lr});
  // Side-effect chain: each Assign is sequenced by an UpdateState, and the
  // data output is threaded through Depend nodes so it stays ordered after
  // every assign.
  VectorRef tuple_load = VectorRef({prim::kPrimMakeTuple, load_param, load_m, load_v});
  VectorRef next_state = VectorRef({prim::kPrimUpdateState, u_, tuple_load});
  VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, next_state});
  next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_param});
  next_param = VectorRef({prim::kPrimDepend, next_param, assign_param});
  VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state});
  next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m});
  next_param = VectorRef({prim::kPrimDepend, next_param, assign_m});
  VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state});
  // Root of the pattern: Depend(next_param, assign_v).
  next_param = VectorRef({prim::kPrimDepend, next_param, assign_v});
  return next_param;
}
  108. const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
  109. MS_EXCEPTION_IF_NULL(graph);
  110. MS_EXCEPTION_IF_NULL(node);
  111. MS_EXCEPTION_IF_NULL(equiv);
  112. auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
  113. auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
  114. auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
  115. auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
  116. auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
  117. auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
  118. auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
  119. auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
  120. auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
  121. auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
  122. auto u_input = utils::cast<AnfNodePtr>((*equiv)[u_]);
  123. MS_EXCEPTION_IF_NULL(beta1_input);
  124. MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
  125. MS_EXCEPTION_IF_NULL(beta2_input);
  126. MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
  127. MS_EXCEPTION_IF_NULL(eps_input);
  128. MS_EXCEPTION_IF_NULL(lr_input);
  129. MS_EXCEPTION_IF_NULL(param_input);
  130. MS_EXCEPTION_IF_NULL(m_input);
  131. MS_EXCEPTION_IF_NULL(v_input);
  132. MS_EXCEPTION_IF_NULL(gradient_input);
  133. MS_EXCEPTION_IF_NULL(u_input);
  134. // Use depend(param, u) to maintain the execution order of FusedAdam and the previous operators.
  135. auto prim_depend = std::make_shared<Primitive>(prim::kPrimDepend->name());
  136. MS_EXCEPTION_IF_NULL(prim_depend);
  137. std::vector<AnfNodePtr> param_inputs = {NewValueNode(prim_depend), param_input, u_input};
  138. auto param = graph->NewCNode(param_inputs);
  139. MS_EXCEPTION_IF_NULL(param);
  140. param->set_abstract(param_input->abstract());
  141. // Fused into a FusedAdam operator.
  142. auto prim = std::make_shared<Primitive>(kFusedAdamName);
  143. MS_EXCEPTION_IF_NULL(prim);
  144. auto prim_value = NewValueNode(prim);
  145. std::vector<AnfNodePtr> inputs = {
  146. prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param,
  147. m_input, v_input, gradient_input};
  148. auto adam = graph->NewCNode(inputs);
  149. MS_EXCEPTION_IF_NULL(adam);
  150. auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
  151. auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
  152. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
  153. adam->set_scope(node->scope());
  154. auto build_info = GenerateKernelBuildInfo(adam);
  155. AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
  156. return RelpaceOutputEdge(node, adam, u_input);
  157. }
  158. } // namespace opt
  159. } // namespace mindspore