You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel_reuse.cc 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "optimizer/graph_kernel_reuse.h"
  17. #include <vector>
  18. #include <algorithm>
  19. #include <string>
  20. #include "./common.h"
  21. #include "utils/graph_utils.h"
  22. namespace mindspore {
  23. /* namespace to support opt */
  24. namespace opt {
  25. bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) {
  26. if (a->abstract() && b->abstract()) {
  27. auto a_type = a->abstract()->GetTypeTrack();
  28. auto b_type = b->abstract()->GetTypeTrack();
  29. if (a_type != b_type) {
  30. return false;
  31. }
  32. auto a_shape = a->abstract()->GetShapeTrack();
  33. auto b_shape = b->abstract()->GetShapeTrack();
  34. if (a_shape != nullptr && a_shape == b_shape) {
  35. return true;
  36. }
  37. if (a_shape != nullptr && b_shape != nullptr && a_shape->isa<abstract::Shape>() &&
  38. b_shape->isa<abstract::Shape>()) {
  39. return a_shape->cast<abstract::ShapePtr>()->shape() == b_shape->cast<abstract::ShapePtr>()->shape();
  40. }
  41. }
  42. return false;
  43. }
  44. bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) {
  45. bool changed = false;
  46. auto fgs = manager->func_graphs();
  47. for (FuncGraphPtr &fg : fgs) {
  48. if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  49. continue;
  50. }
  51. std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  52. if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) {
  53. if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) {
  54. FuncGraphPtr new_fg = nullptr;
  55. for (auto &cfg : graph_kernel_ops[key]) {
  56. // If two graphs have different size then continue
  57. auto fg_topos = TopoSort(fg->get_return());
  58. auto cfg_topos = TopoSort(cfg->get_return());
  59. if (fg_topos.size() != cfg_topos.size()) {
  60. continue;
  61. }
  62. // Compare const tensor
  63. bool has_same = true;
  64. for (size_t i = 0; i < fg_topos.size(); ++i) {
  65. if (IsValueNode<tensor::Tensor>(fg_topos[i])) {
  66. if (!IsValueNode<tensor::Tensor>(cfg_topos[i])) {
  67. has_same = false;
  68. break;
  69. }
  70. auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]);
  71. auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]);
  72. if (!tensor1->ValueEqual(*tensor2)) {
  73. has_same = false;
  74. break;
  75. }
  76. }
  77. }
  78. if (!has_same) {
  79. continue;
  80. }
  81. auto fg_input = fg->parameters();
  82. auto cfg_input = cfg->parameters();
  83. if (fg_input.size() != cfg_input.size()) {
  84. continue;
  85. }
  86. // Compare input
  87. for (size_t i = 0; i < fg_input.size(); ++i) {
  88. if (!CompareNode(fg_input[i], cfg_input[i])) {
  89. has_same = false;
  90. break;
  91. }
  92. }
  93. if (!has_same) {
  94. continue;
  95. }
  96. // Compare output
  97. if (!CompareNode(fg->output(), cfg->output())) {
  98. continue;
  99. }
  100. // Find reusable fg
  101. new_fg = cfg;
  102. break;
  103. }
  104. if (new_fg != nullptr) {
  105. // Replace current fg with existing fg
  106. auto users = fg->func_graph_cnodes_index();
  107. for (auto &iter : users) {
  108. auto cnode = iter.first->first->cast<CNodePtr>();
  109. auto new_input = cnode->inputs();
  110. auto main_graph = cnode->func_graph();
  111. MS_EXCEPTION_IF_NULL(main_graph);
  112. if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
  113. new_input[1] = NewValueNode(new_fg);
  114. } else {
  115. new_input[0] = NewValueNode(new_fg);
  116. }
  117. auto new_cnode = main_graph->NewCNode(new_input);
  118. manager->Replace(iter.first->first, new_cnode);
  119. changed = true;
  120. }
  121. } else {
  122. // Add current fg to map
  123. graph_kernel_ops[key].push_back(fg);
  124. }
  125. }
  126. } else {
  127. graph_kernel_ops[key] = {fg};
  128. }
  129. }
  130. return changed;
  131. }
  132. bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
  133. MS_EXCEPTION_IF_NULL(manager);
  134. manager->AddFuncGraph(root);
  135. return DoReplace(manager);
  136. }
  137. } // namespace opt
  138. } // namespace mindspore