You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimize_dependence.cc 8.7 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/pass/optimize_dependence.h"
  17. #include <memory>
  18. #include <vector>
  19. #include <string>
  20. #include <utility>
  21. #include "backend/optimizer/common/helper.h"
  22. #include "base/core_ops.h"
  23. #include "utils/utils.h"
  24. #include "backend/session/kernel_graph.h"
  25. #include "backend/session/anf_runtime_algorithm.h"
  26. namespace mindspore {
  27. namespace opt {
  // Index of the single data input of a Cast/TransData CNode; inputs()[0] holds the primitive.
  constexpr size_t kSingleInputIndex = 1;
  // Positions of the real input and the virtual (monad) input of an isolated Depend/Load,
  // counted as AnfAlgo::GetInputNode indices (primitive excluded). Add 1 to index inputs().
  // These are declared size_t because every use site is a size_t index.
  constexpr size_t kIsolatedDependRealInputIndex = 0;
  constexpr size_t kIsolatedDependVirtualInputIndex = 1;
  31. namespace {
  32. CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
  33. const std::vector<AnfNodePtr> &new_depend_inputs) {
  34. MS_EXCEPTION_IF_NULL(func_graph);
  35. MS_EXCEPTION_IF_NULL(cnode);
  36. auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  37. if (kernel_graph == nullptr) {
  38. auto new_depend = func_graph->NewCNode(new_depend_inputs);
  39. MS_EXCEPTION_IF_NULL(new_depend);
  40. new_depend->set_abstract(cnode->abstract());
  41. new_depend->set_scope(cnode->scope());
  42. return new_depend;
  43. }
  44. auto new_depend = kernel_graph->NewCNode(cnode);
  45. MS_EXCEPTION_IF_NULL(new_depend);
  46. new_depend->set_inputs(new_depend_inputs);
  47. return new_depend;
  48. }
  49. CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
  50. MS_EXCEPTION_IF_NULL(cnode);
  51. if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name() &&
  52. AnfAlgo::GetCNodeName(cnode) != prim::kPrimLoad->name()) {
  53. return nullptr;
  54. }
  55. auto virtual_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
  56. if (!HasAbstractMonad(virtual_input_op)) {
  57. return nullptr;
  58. }
  59. auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
  60. MS_EXCEPTION_IF_NULL(real_input_op);
  61. if (!real_input_op->isa<CNode>()) {
  62. return nullptr;
  63. }
  64. auto real_input_cnode = real_input_op->cast<CNodePtr>();
  65. return real_input_cnode;
  66. }
  67. AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
  68. const CNodePtr &eliminate_node) {
  69. MS_EXCEPTION_IF_NULL(func_graph);
  70. MS_EXCEPTION_IF_NULL(cnode);
  71. MS_EXCEPTION_IF_NULL(eliminate_node);
  72. auto replace_node = eliminate_node->input(kSingleInputIndex);
  73. std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs();
  74. new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node;
  75. auto new_depend = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
  76. (void)func_graph->manager()->Replace(cnode, new_depend);
  77. return new_depend;
  78. }
  79. AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  80. MS_EXCEPTION_IF_NULL(func_graph);
  81. MS_EXCEPTION_IF_NULL(node);
  82. if (!node->isa<CNode>()) {
  83. return nullptr;
  84. }
  85. auto cnode = node->cast<CNodePtr>();
  86. MS_EXCEPTION_IF_NULL(cnode);
  87. auto replace_cnode = cnode;
  88. // Process updatestate and depend as isolated node env.
  89. auto isolated_cnode = CheckIsolatedVirtualNode(replace_cnode);
  90. if (isolated_cnode != nullptr) {
  91. replace_cnode = isolated_cnode;
  92. }
  93. string op_name = AnfAlgo::GetCNodeName(replace_cnode);
  94. // Currently we only eliminate transdata or cast nodes.
  95. if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
  96. return nullptr;
  97. }
  98. if (!IsNotRealUsedByOthers(func_graph, replace_cnode)) {
  99. return nullptr;
  100. }
  101. CheckCNodeInputSize(replace_cnode, kSingleInputIndex);
  102. if (isolated_cnode != nullptr) {
  103. auto new_depend_node = EliminateIsolatedVirtualNodeInput(func_graph, cnode, replace_cnode);
  104. return new_depend_node;
  105. }
  106. return cnode->input(kSingleInputIndex);
  107. }
  108. AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  109. MS_EXCEPTION_IF_NULL(func_graph);
  110. MS_EXCEPTION_IF_NULL(cnode);
  111. if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
  112. return nullptr;
  113. }
  114. std::vector<AnfNodePtr> new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  115. bool need_update = false;
  116. size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  117. for (size_t index = 0; index < input_num; ++index) {
  118. auto input = AnfAlgo::GetInputNode(cnode, index);
  119. AnfNodePtr replace_input = GetReplaceNode(func_graph, input);
  120. // If replace input is not null, it will be the input of the TransData or Cast.
  121. if (replace_input == nullptr) {
  122. new_make_tuple_inputs.push_back(input);
  123. continue;
  124. }
  125. new_make_tuple_inputs.push_back(replace_input);
  126. need_update = true;
  127. }
  128. if (need_update) {
  129. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  130. CNodePtr new_make_tuple = nullptr;
  131. if (kernel_graph == nullptr) {
  132. new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs);
  133. } else {
  134. new_make_tuple = kernel_graph->NewCNode(cnode);
  135. }
  136. MS_EXCEPTION_IF_NULL(new_make_tuple);
  137. new_make_tuple->set_inputs(new_make_tuple_inputs);
  138. auto manager = func_graph->manager();
  139. MS_EXCEPTION_IF_NULL(manager);
  140. manager->Replace(cnode, new_make_tuple);
  141. return new_make_tuple;
  142. }
  143. return nullptr;
  144. }
  145. } // namespace
  146. const BaseRef OptimizeDependence::DefinePattern() const {
  147. VarPtr X = std::make_shared<Var>();
  148. VarPtr Xs = std::make_shared<SeqVar>();
  149. return VectorRef({X, Xs});
  150. }
  151. std::vector<size_t> SearchTransDataAndCast(const CNodePtr &cnode) {
  152. // Search Depend and UpdateState only.
  153. if (!cnode->IsApply(prim::kPrimDepend) && !cnode->IsApply(prim::kPrimUpdateState)) {
  154. return {};
  155. }
  156. // Find inputs which is Cast or TransData.
  157. std::vector<size_t> result;
  158. for (size_t i = 1; i < cnode->size(); ++i) {
  159. auto &input = cnode->input(i);
  160. if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) ||
  161. AnfAlgo::CheckPrimitiveType(input, prim::kPrimTransData) ||
  162. AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
  163. (void)result.emplace_back(i);
  164. }
  165. }
  166. return result;
  167. }
  168. const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  169. const EquivPtr &) const {
  170. MS_EXCEPTION_IF_NULL(func_graph);
  171. MS_EXCEPTION_IF_NULL(node);
  172. auto cnode = dyn_cast<CNode>(node);
  173. if (cnode == nullptr) {
  174. return nullptr;
  175. }
  176. // Search inputs to be replaced.
  177. auto candidate_inputs = SearchTransDataAndCast(cnode);
  178. if (candidate_inputs.empty()) {
  179. return nullptr;
  180. }
  181. // Get new nodes which will act as new inputs of Depend or UpdateState.
  182. std::vector<AnfNodePtr> new_inputs = cnode->inputs();
  183. bool inputs_changed = false;
  184. for (auto index : candidate_inputs) {
  185. if (index >= new_inputs.size()) {
  186. MS_LOG(EXCEPTION) << "Index is out of the size of cnode inputs.";
  187. }
  188. auto replace_node = GetConvertNode(func_graph, cnode, index);
  189. if (replace_node != nullptr) {
  190. new_inputs[index] = replace_node;
  191. inputs_changed = true;
  192. }
  193. }
  194. if (!inputs_changed) {
  195. return nullptr;
  196. }
  197. // Create a new Depend node to replace the old one if inputs changed.
  198. auto new_depend = CreateNewDependNode(func_graph, cnode, new_inputs);
  199. (void)func_graph->manager()->Replace(cnode, new_depend);
  200. return nullptr;
  201. }
  202. const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
  203. const size_t index) const {
  204. MS_EXCEPTION_IF_NULL(graph);
  205. MS_EXCEPTION_IF_NULL(node);
  206. auto depend_cnode = node->cast<CNodePtr>();
  207. MS_EXCEPTION_IF_NULL(depend_cnode);
  208. auto replacing_node = depend_cnode->input(index);
  209. MS_EXCEPTION_IF_NULL(replacing_node);
  210. if (!replacing_node->isa<CNode>()) {
  211. return nullptr;
  212. }
  213. auto replacing_cnode = replacing_node->cast<CNodePtr>();
  214. MS_EXCEPTION_IF_NULL(replacing_cnode);
  215. // Deal with the make_tuple with TransData or Cast inputs.
  216. auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode);
  217. if (make_tuple_replace_node != nullptr) {
  218. return make_tuple_replace_node;
  219. }
  220. AnfNodePtr replace_node = GetReplaceNode(graph, replacing_cnode);
  221. return replace_node;
  222. }
  223. } // namespace opt
  224. } // namespace mindspore