You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transform.h 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_VM_TRANSFORM_H_
  19. #define MINDSPORE_CCSRC_VM_TRANSFORM_H_
  20. #include <string>
  21. #include <memory>
  22. #include <functional>
  23. #include <utility>
  24. #include <unordered_map>
  25. #include <vector>
  26. #include "vm/vm.h"
  27. #include "ir/anf.h"
  28. #include "operator/ops.h"
  29. #include "vm/segment_runner.h"
  30. #include "vm/backend.h"
  31. // mindspore namespace is the top level namespace of Mindsporeession project.
  32. // Other namespace should be a sub namespace of mindspore namespace in the ME project.
  33. namespace mindspore {
  34. extern const char kMsVm[];
  35. extern const char kGeVm[];
  36. // compile namespace
  37. // A sub namespace in ME to support compile related definition.
  38. namespace compile {
  39. extern std::vector<PrimitivePtr> nonlinear_ops;
  40. const std::vector<PrimitivePtr> &GetMsNonlinearOps();
  41. using VmEvalFunc = std::function<BaseRef(const VectorRef &)>;
  42. using VmEvalFuncPtr = std::shared_ptr<std::function<BaseRef(const VectorRef &)>>;
  43. class CompileGraph {
  44. public:
  45. explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  46. ~CompileGraph() = default;
  47. InstSet Run(const FuncGraphPtr &func_graph);
  48. InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph);
  49. bool IsCut(const AnfNodePtr &node);
  50. void Push(const AnfNodePtr &node);
  51. void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
  52. void Ret(int nargs);
  53. void GenMultiGraphsRun(const FuncGraphPtr &graph);
  54. int Ref(const AnfNodePtr &node);
  55. VectorRef SplitNodes(const FuncGraphPtr &func_graph);
  56. void set_height(int h) {
  57. height_ = h;
  58. if (height_ > max_height_) {
  59. max_height_ = height_;
  60. }
  61. }
  62. void Reset() {
  63. height_ = 0;
  64. max_height_ = 0;
  65. slots_.clear();
  66. inst_.clear();
  67. }
  68. private:
  69. void PushParameters(const FuncGraphPtr &func_graph);
  70. std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph);
  71. bool SplitGraph(const FuncGraphPtr &func_graph);
  72. int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
  73. int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
  74. int AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
  75. void AddSinkSwitch(const CNodePtr &node);
  76. void AddPadStack(int param_height);
  77. void AddTailCall(const AnfNodePtr &fn, size_t size);
  78. void AddPartial(const CNodePtr &node);
  79. void AddMakeTuple(const CNodePtr &node);
  80. void AddSwitch(const CNodePtr &node);
  81. void AddReturn(const CNodePtr &node);
  82. void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
  83. void AddInput(const AnfNodePtr &node);
  84. void AddExternal(const LinConvertResult &result);
  85. void AddInst(const Instruction &inst, const int &arg);
  86. void AddInst(const Instruction &inst, const ValuePtr &arg);
  87. void AddInst(const Instruction &inst, const VectorRef &args);
  88. BackendPtr backend_;
  89. LinkFuncType lin_convert_;
  90. bool is_gevm_convert_;
  91. int height_{0};
  92. int max_height_{0};
  93. std::vector<PrimitivePtr> cut_list_;
  94. std::unordered_map<AnfNodePtr, int> slots_;
  95. InstSet inst_;
  96. };
  97. using CompileGraphPtr = std::shared_ptr<CompileGraph>;
  98. // CompileGraphs is used to Convert a graph cluster into instruction lists.
  99. class CompileGraphs {
  100. public:
  101. explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  102. ~CompileGraphs() = default;
  103. void Reset() {
  104. insts_.clear();
  105. mapping_.clear();
  106. }
  107. void Compile(const FuncGraphPtr &func_graph);
  108. FinalVMPtr Link(const FuncGraphPtr &func_graph);
  109. FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
  110. bool ContainMixedTarget(const FuncGraphPtr &graph);
  111. private:
  112. InstSet insts_;
  113. std::unordered_map<FuncGraphPtr, int> mapping_;
  114. CompileGraphPtr transform_;
  115. BackendPtr backend_;
  116. };
  117. BackendPtr CreateBackend();
  118. } // namespace compile
  119. } // namespace mindspore
  120. #endif // MINDSPORE_CCSRC_VM_TRANSFORM_H_