You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ref_eliminate.h 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
  18. #include <memory>
  19. #include "optimizer/optimizer.h"
  20. #include "optimizer/irpass.h"
  21. #include "ir/visitor.h"
  22. #include "operator/ops.h"
  23. namespace mindspore {
  24. namespace opt {
  25. namespace irpass {
  26. // {prim::kPrimMakeRef, X, Y, Z} -> Y
  27. class MakeRefEliminater : public AnfVisitor {
  28. public:
  29. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  30. y_ = nullptr;
  31. auto gety = [this](const AnfNodePtr &node) -> bool {
  32. this->y_ = node;
  33. return true;
  34. };
  35. AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node);
  36. return y_;
  37. }
  38. void Visit(const AnfNodePtr &) override {}
  39. private:
  40. AnfNodePtr y_{nullptr};
  41. };
  42. // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
  43. // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
  44. class GetMakeRefEliminater : public AnfVisitor {
  45. public:
  46. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  47. auto cnode = node->cast<CNodePtr>();
  48. if (cnode == nullptr || cnode->size() != 2) {
  49. return nullptr;
  50. }
  51. // {prim::kPrimGetRefKey/Value, {...}}
  52. auto ref = cnode->input(1)->cast<CNodePtr>();
  53. if (ref == nullptr || !ref->IsApply(prim::kPrimMakeRef) || ref->size() != 4) {
  54. return nullptr;
  55. }
  56. // {prim::kPrimMakeRef, X, Y, Z}
  57. if (cnode->IsApply(prim::kPrimGetRefKey)) {
  58. return ref->input(1);
  59. }
  60. if (cnode->IsApply(prim::kPrimGetRefValue)) {
  61. return ref->input(2);
  62. }
  63. return nullptr;
  64. }
  65. };
  66. // IsValueNode<RefKey>
  67. class ReplaceRefkeyByParam : public AnfVisitor {
  68. public:
  69. AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
  70. if (!IsValueNode<RefKey>(node)) {
  71. return nullptr;
  72. }
  73. auto refkey = GetValueNode<RefKeyPtr>(node);
  74. auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
  75. MS_EXCEPTION_IF_NULL(resource);
  76. auto top_graph = resource->func_graph();
  77. MS_EXCEPTION_IF_NULL(top_graph);
  78. for (const auto &tnode : top_graph->parameters()) {
  79. auto para = tnode->cast<ParameterPtr>();
  80. if (para != nullptr && para->name() == refkey->tag()) {
  81. return para;
  82. }
  83. }
  84. return nullptr;
  85. }
  86. };
  87. } // namespace irpass
  88. } // namespace opt
  89. } // namespace mindspore
  90. #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_