You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel_reuse.cc 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/graph_kernel_reuse.h"
  17. #include <vector>
  18. #include <algorithm>
  19. #include <string>
  20. #include "ir/graph_utils.h"
  21. namespace mindspore {
  22. /* namespace to support opt */
  23. namespace opt {
  24. bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) {
  25. if (a->abstract() && b->abstract()) {
  26. auto a_type = a->abstract()->GetTypeTrack();
  27. auto b_type = b->abstract()->GetTypeTrack();
  28. if (a_type != b_type) {
  29. return false;
  30. }
  31. auto a_shape = a->abstract()->GetShapeTrack();
  32. auto b_shape = b->abstract()->GetShapeTrack();
  33. if (a_shape != nullptr && a_shape == b_shape) {
  34. return true;
  35. }
  36. if (a_shape != nullptr && b_shape != nullptr && a_shape->isa<abstract::Shape>() &&
  37. b_shape->isa<abstract::Shape>()) {
  38. return a_shape->cast<abstract::ShapePtr>()->shape() == b_shape->cast<abstract::ShapePtr>()->shape();
  39. }
  40. }
  41. return false;
  42. }
  43. bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) {
  44. bool changed = false;
  45. auto fgs = manager->func_graphs();
  46. for (FuncGraphPtr &fg : fgs) {
  47. if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  48. continue;
  49. }
  50. std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  51. if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) {
  52. if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) {
  53. FuncGraphPtr new_fg = nullptr;
  54. for (auto &cfg : graph_kernel_ops[key]) {
  55. // If two graphs have different size then continue
  56. auto fg_topos = TopoSort(fg->get_return());
  57. auto cfg_topos = TopoSort(cfg->get_return());
  58. if (fg_topos.size() != cfg_topos.size()) {
  59. continue;
  60. }
  61. // Compare const tensor
  62. bool has_same = true;
  63. for (size_t i = 0; i < fg_topos.size(); ++i) {
  64. if (IsValueNode<tensor::Tensor>(fg_topos[i])) {
  65. if (!IsValueNode<tensor::Tensor>(cfg_topos[i])) {
  66. has_same = false;
  67. break;
  68. }
  69. auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]);
  70. auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]);
  71. if (!tensor1->ValueEqual(*tensor2)) {
  72. has_same = false;
  73. break;
  74. }
  75. }
  76. }
  77. if (!has_same) {
  78. continue;
  79. }
  80. auto fg_input = fg->parameters();
  81. auto cfg_input = cfg->parameters();
  82. if (fg_input.size() != cfg_input.size()) {
  83. continue;
  84. }
  85. // Compare input
  86. for (size_t i = 0; i < fg_input.size(); ++i) {
  87. if (!CompareNode(fg_input[i], cfg_input[i])) {
  88. has_same = false;
  89. break;
  90. }
  91. }
  92. if (!has_same) {
  93. continue;
  94. }
  95. // Compare output
  96. if (!CompareNode(fg->output(), cfg->output())) {
  97. continue;
  98. }
  99. // Find reusable fg
  100. new_fg = cfg;
  101. break;
  102. }
  103. if (new_fg != nullptr) {
  104. // Replace current fg with existing fg
  105. auto users = fg->func_graph_cnodes_index();
  106. for (auto &iter : users) {
  107. auto cnode = iter.first->first->cast<CNodePtr>();
  108. auto new_input = cnode->inputs();
  109. auto main_graph = cnode->func_graph();
  110. MS_EXCEPTION_IF_NULL(main_graph);
  111. if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
  112. new_input[1] = NewValueNode(new_fg);
  113. } else {
  114. new_input[0] = NewValueNode(new_fg);
  115. }
  116. auto new_cnode = main_graph->NewCNode(new_input);
  117. manager->Replace(iter.first->first, new_cnode);
  118. changed = true;
  119. }
  120. } else {
  121. // Add current fg to map
  122. graph_kernel_ops[key].push_back(fg);
  123. }
  124. }
  125. } else {
  126. graph_kernel_ops[key] = {fg};
  127. }
  128. }
  129. return changed;
  130. }
  131. bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
  132. MS_EXCEPTION_IF_NULL(manager);
  133. manager->AddFuncGraph(root);
  134. return DoReplace(manager);
  135. }
  136. } // namespace opt
  137. } // namespace mindspore