You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ref_eliminate.h 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
  18. #include <memory>
  19. #include "ir/pattern_matcher.h"
  20. #include "frontend/optimizer/irpass.h"
  21. #include "frontend/optimizer/optimizer.h"
  22. namespace mindspore {
  23. namespace opt {
  24. namespace irpass {
  25. namespace internal {
  26. class GetRefValueTransform {
  27. public:
  28. GetRefValueTransform() {}
  29. ~GetRefValueTransform() = default;
  30. AnfNodePtr operator()(const AnfNodePtr &node) {
  31. CNodePtr cnode = node->cast<CNodePtr>();
  32. auto inputs = cnode->inputs();
  33. auto fg = GetValueNode(inputs[0])->cast<FuncGraphPtr>();
  34. if (fg->recursive()) {
  35. MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << fg->ToString();
  36. return node;
  37. }
  38. auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("GetRefValue"));
  39. auto output = new_fg->output();
  40. new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output}));
  41. inputs[0] = NewValueNode(new_fg);
  42. auto ret_node = cnode->func_graph()->NewCNode(inputs);
  43. return ret_node;
  44. }
  45. };
  46. } // namespace internal
  47. // {prim::kPrimMakeRef, X, Y, Z} -> Y
  48. class MakeRefEliminater : public OptimizerCaller {
  49. public:
  50. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  51. PatternNode<AnfNodePtr> x, y, z;
  52. MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y);
  53. return nullptr;
  54. }
  55. };
  56. // {prim::kPrimGetRefValue, Parameter} -> Parameter
  57. class GetRefParamEliminater : public OptimizerCaller {
  58. public:
  59. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  60. PatternNode<AnfNodePtr> x;
  61. MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
  62. return nullptr;
  63. }
  64. };
  65. // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
  66. // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
  67. // {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f}
  68. class GetMakeRefEliminater : public OptimizerCaller {
  69. public:
  70. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  71. PatternNode<AnfNodePtr> x, y, z;
  72. MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
  73. MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
  74. MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node));
  75. internal::GetRefValueTransform trans;
  76. auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr {
  77. auto rep = trans(x.GetNode(node));
  78. if (rep != nullptr) {
  79. return rep;
  80. }
  81. return nullptr;
  82. };
  83. MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node));
  84. return nullptr;
  85. }
  86. };
  87. // IsValueNode<RefKey>
  88. class ReplaceRefkeyByParam : public OptimizerCaller {
  89. public:
  90. AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
  91. auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr {
  92. auto refkey = GetValueNode<RefKeyPtr>(node);
  93. auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
  94. MS_EXCEPTION_IF_NULL(resource);
  95. auto top_graph = resource->func_graph();
  96. MS_EXCEPTION_IF_NULL(top_graph);
  97. for (const auto &tnode : top_graph->parameters()) {
  98. auto para = tnode->cast<ParameterPtr>();
  99. if (para != nullptr && para->name() == refkey->tag()) {
  100. return para;
  101. }
  102. }
  103. return nullptr;
  104. };
  105. PatternNode<AnfNodePtr> x;
  106. MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node));
  107. return nullptr;
  108. }
  109. };
  110. } // namespace irpass
  111. } // namespace opt
  112. } // namespace mindspore
  113. #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_