You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

load_eliminate.cc 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/irpass/load_eliminate.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <set>
  20. #include <vector>
  21. #include "frontend/operator/ops.h"
  22. namespace mindspore::opt::irpass {
  23. // Covert:
  24. // load1 = load(para1, u1)
  25. // u2 = UpdateState(u1, load1)
  26. // ...
  27. // load2 = load(load1, u3)
  28. // u4 = UpdateState(u3, load2)
  29. // To:
  30. // load1 = load(para1, u1)
  31. // u2 = UpdateState(u1, load1)
  32. // ...
  33. // load2 = load(para1, u3) # load1 replaced by para1
  34. // u4 = UpdateState(u3, load2)
  35. AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  36. auto load_node = dyn_cast<CNode>(node);
  37. if (load_node == nullptr || load_node->inputs().empty()) {
  38. MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString();
  39. return nullptr;
  40. }
  41. auto load_cnode = load_node->cast<CNodePtr>();
  42. constexpr size_t kFirstInputIndex = 1;
  43. constexpr size_t kSecondInputIndex = 2;
  44. auto &input_load = load_cnode->input(kFirstInputIndex);
  45. if (IsPrimitiveCNode(input_load, prim::kPrimLoad)) {
  46. auto load_prim = NewValueNode(prim::kPrimLoad);
  47. auto input_load_cnode = input_load->cast<CNodePtr>();
  48. auto replace_input = input_load_cnode->input(kFirstInputIndex);
  49. auto monad = load_cnode->input(kSecondInputIndex);
  50. std::vector<AnfNodePtr> new_load_inputs = {load_prim, replace_input, monad};
  51. auto fg = load_cnode->func_graph();
  52. MS_EXCEPTION_IF_NULL(fg);
  53. auto new_load = fg->NewCNode(new_load_inputs);
  54. new_load->set_abstract(load_cnode->abstract());
  55. new_load->set_scope(load_cnode->scope());
  56. return new_load;
  57. }
  58. return nullptr;
  59. }
  60. } // namespace mindspore::opt::irpass