You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_var_prepare.cc 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "optimizer/irpass/grad_var_prepare.h"
  17. #include <vector>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <memory>
  21. #include "operator/composite/composite.h"
  22. #include "operator/ops.h"
  23. #include "optimizer/irpass.h"
  24. #include "optimizer/optimizer.h"
  25. #include "ir/visitor.h"
  26. #include "ir/func_graph.h"
  27. #include "ir/func_graph_cloner.h"
  28. namespace mindspore {
  29. namespace opt {
  30. namespace irpass {
  31. static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
  32. AnfNodePtr func_node, bool is_unpack, bool sens_param) {
  33. MS_EXCEPTION_IF_NULL(func_graph);
  34. MS_EXCEPTION_IF_NULL(func_node);
  35. std::vector<AnfNodePtr> nodes;
  36. AnfNodePtr unpack_graph_node = nullptr;
  37. if (is_unpack) {
  38. auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, true);
  39. nodes.push_back(NewValueNode(unpack_graph));
  40. nodes.push_back(func_node);
  41. // {unpackcall, {GradOperation, ...}, args...}
  42. std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes),
  43. [](const AnfNodePtr &node) { return node; });
  44. unpack_graph_node = func_graph->NewCNode(nodes);
  45. } else {
  46. auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
  47. nodes.push_back(NewValueNode(unpack_graph));
  48. nodes.push_back(func_node);
  49. // {{GradOperation, ...}, args...}
  50. std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes),
  51. [](const AnfNodePtr &node) { return node; });
  52. unpack_graph_node = func_graph->NewCNode(nodes);
  53. }
  54. return unpack_graph_node;
  55. }
  56. // get metagraph of value node
  57. MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
  58. ValuePtr value;
  59. if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
  60. value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
  61. } else {
  62. value = GetValueNode(node);
  63. }
  64. if (value == nullptr) {
  65. return nullptr;
  66. }
  67. return value->cast<MetaFuncGraphPtr>();
  68. }
  69. // check if node is a specific metafuncgraph op
  70. bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) {
  71. if (node != nullptr) {
  72. auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
  73. if (meta_func_graph_ptr == nullptr) {
  74. return false;
  75. }
  76. if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) {
  77. return true;
  78. }
  79. }
  80. return false;
  81. }
  82. // {{GradOperation, g, w}, Ys}
  83. // {UnPackCall, {GradOperation, g, w}, Ys}
  84. AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  85. if (!node->isa<CNode>() || node->func_graph() == nullptr) {
  86. return nullptr;
  87. }
  88. // {{...}, Ys}
  89. auto inputs_y = node->cast<CNodePtr>()->inputs();
  90. std::vector<AnfNodePtr> inputs_x;
  91. if (IsCNode(inputs_y[0])) {
  92. inputs_x = inputs_y[0]->cast<CNodePtr>()->inputs();
  93. } else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) {
  94. inputs_x = inputs_y[1]->cast<CNodePtr>()->inputs();
  95. } else {
  96. return nullptr;
  97. }
  98. // {{...}, Xs}
  99. if (inputs_x.size() < 2) {
  100. return nullptr;
  101. }
  102. // {GradOperation, g, w} or {GradOperation, g}
  103. if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) {
  104. return nullptr;
  105. }
  106. auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
  107. if (meta_func == nullptr) {
  108. return nullptr;
  109. }
  110. auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
  111. auto func_node = inputs_x[1];
  112. if (!IsValueNode<FuncGraph>(func_node)) {
  113. return nullptr;
  114. }
  115. AnfNodePtr unpack_graph_node =
  116. GenerateUnpackGraphNode(inputs_y, node->cast<CNodePtr>()->func_graph(), func_node,
  117. IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param());
  118. // constuct new grad_opration
  119. inputs_x[1] = unpack_graph_node;
  120. auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x);
  121. if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) {
  122. inputs_y[1] = grad_op_cnode;
  123. } else {
  124. inputs_y[0] = grad_op_cnode;
  125. }
  126. auto cnode = node->func_graph()->NewCNode(inputs_y);
  127. return cnode;
  128. }
  129. } // namespace irpass
  130. } // namespace opt
  131. } // namespace mindspore