You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functionalize_while.h 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_
  17. #define MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_
  #include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/graph/functionalize_control_op_pass.h"
  26. using mindspore::lite::converter::FmkType;
  27. namespace mindspore::opt {
  28. constexpr const int POS_INVALID = -1;
  29. class FunctionalizeWhile {
  30. public:
  31. FunctionalizeWhile(std::vector<AnfNodePtr> node_cluster, const CNodePtr &loop_cond_node, FuncGraphPtr fg)
  32. : node_cluster_(node_cluster), loop_cond_node_(loop_cond_node), fg_(fg) {}
  33. // while
  34. STATUS BuildWhileNode();
  35. STATUS IdentifyWhileNodeInput();
  36. STATUS IdentifyWhileNodeExternalInput();
  37. STATUS IdentifyWhileNodeOutput();
  38. STATUS UpdateExitNodeUser();
  39. STATUS NewWhileNode();
  40. STATUS InsertFuncGraphToWhileInput();
  41. bool WhileNodeExternalInputIsContain(const AnfNodePtr &node);
  42. // cond subgraph
  43. STATUS BuildCondGraph();
  44. STATUS CondSubgraphAddNodes();
  45. STATUS IdentifyCondSubgraphInput();
  46. STATUS IdentifyCondSubgraphOutput();
  47. // body subgraph
  48. STATUS BuildBodyGraph();
  49. STATUS BodySubgraphAddNodes();
  50. STATUS IdentifyBodySubgraphInput();
  51. STATUS IdentifyBodySubgraphOutput();
  52. CNodePtr BlongToWhichSwitch(const CNodePtr &node);
  53. CNodePtr BlongToWhichMerge(const CNodePtr &node);
  54. CNodePtr BlongToWhichEnter(const CNodePtr &node);
  55. CNodePtr BlongToWhichExternalEnter(const CNodePtr &node);
  56. int PosInInputEnterNodes(const CNodePtr &node);
  57. STATUS DropUselessNodesInMainGraph();
  58. STATUS Process();
  59. private:
  60. std::vector<AnfNodePtr> node_cluster_{};
  61. const CNodePtr loop_cond_node_;
  62. FuncGraphPtr fg_;
  63. FuncGraphPtr cond_sub_func_graph_ = nullptr;
  64. FuncGraphPtr body_sub_func_graph_ = nullptr;
  65. CNodePtr while_node_ = nullptr;
  66. std::string cond_subgraph_name_{};
  67. std::string body_subgraph_name_{};
  68. // while
  69. std::vector<CNodePtr> input_enter_nodes_{};
  70. std::vector<CNodePtr> external_input_enter_nodes_{};
  71. std::vector<CNodePtr> output_exit_nodes_{};
  72. // pair (next iteration node, next iteration node input)
  73. std::map<AnfNodePtr, AnfNodePtr> body_subgraph_output_map_{};
  74. // pair (switch node, switch output in body graph)
  75. std::map<AnfNodePtr, AnfNodePtr> body_subgraph_input_map_{};
  76. // pair (switch node, switch output in body graph)
  77. std::map<AnfNodePtr, AnfNodePtr> cond_subgraph_input_map_{};
  78. };
  79. } // namespace mindspore::opt
  80. #endif // MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_