You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

adam_fusion.cc 5.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/gpu/adam_fusion.h"
  17. #include <memory>
  18. #include <vector>
  19. #include <string>
  20. #include "backend/session/anf_runtime_algorithm.h"
  21. #include "ir/primitive.h"
  22. #include "utils/utils.h"
  23. #include "backend/optimizer/common/helper.h"
  24. namespace mindspore {
  25. namespace opt {
  26. namespace {
  27. kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  28. std::vector<std::string> inputs_format;
  29. std::vector<std::string> outputs_format;
  30. std::vector<TypeId> inputs_type;
  31. std::vector<TypeId> outputs_type;
  32. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  33. for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
  34. inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
  35. inputs_format.push_back(kOpFormat_DEFAULT);
  36. }
  37. for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
  38. outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
  39. outputs_format.push_back(kOpFormat_DEFAULT);
  40. }
  41. builder.SetInputsDeviceType(inputs_type);
  42. builder.SetInputsFormat(inputs_format);
  43. builder.SetOutputsDeviceType(outputs_type);
  44. builder.SetOutputsFormat(outputs_format);
  45. return builder.Build();
  46. }
  47. } // namespace
  48. const BaseRef AdamFusion::DefinePattern() const {
  49. VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
  50. VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
  51. VectorRef next_v =
  52. VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
  53. VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
  54. VectorRef update = VectorRef(
  55. {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
  56. VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
  57. VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
  58. next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})});
  59. next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})});
  60. next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})});
  61. return next_param;
  62. }
  63. const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
  64. MS_EXCEPTION_IF_NULL(graph);
  65. MS_EXCEPTION_IF_NULL(node);
  66. MS_EXCEPTION_IF_NULL(equiv);
  67. auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
  68. auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
  69. auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
  70. auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
  71. auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
  72. auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
  73. auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
  74. auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
  75. auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
  76. auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
  77. MS_EXCEPTION_IF_NULL(beta1_input);
  78. MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
  79. MS_EXCEPTION_IF_NULL(beta2_input);
  80. MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
  81. MS_EXCEPTION_IF_NULL(eps_input);
  82. MS_EXCEPTION_IF_NULL(lr_input);
  83. MS_EXCEPTION_IF_NULL(param_input);
  84. MS_EXCEPTION_IF_NULL(m_input);
  85. MS_EXCEPTION_IF_NULL(v_input);
  86. MS_EXCEPTION_IF_NULL(gradient_input);
  87. auto prim = std::make_shared<Primitive>(kFusedAdamName);
  88. MS_EXCEPTION_IF_NULL(prim);
  89. std::vector<AnfNodePtr> inputs = {
  90. NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
  91. eps_input, lr_input, param_input, m_input, v_input,
  92. gradient_input};
  93. auto adam = graph->NewCNode(inputs);
  94. MS_EXCEPTION_IF_NULL(adam);
  95. auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
  96. auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
  97. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
  98. adam->set_scope(node->scope());
  99. auto build_info = GenerateKernelBuildInfo(adam);
  100. AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
  101. return adam;
  102. }
  103. } // namespace opt
  104. } // namespace mindspore