You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

transform.h 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_VM_TRANSFORM_H_
  19. #define MINDSPORE_CCSRC_VM_TRANSFORM_H_
  20. #include <string>
  21. #include <memory>
  22. #include <functional>
  23. #include <utility>
  24. #include <unordered_map>
  25. #include <vector>
  26. #include "vm/vm.h"
  27. #include "ir/anf.h"
  28. #include "operator/ops.h"
  29. #include "vm/segment_runner.h"
  30. #include "vm/backend.h"
// The mindspore namespace is the top-level namespace of the MindSpore project.
// All other namespaces in the ME project should be nested inside the mindspore namespace.
  33. namespace mindspore {
  34. extern const char kMsVm[];
  35. extern const char kGeVm[];
// compile namespace
// A sub-namespace of ME that holds compilation-related definitions.
  38. namespace compile {
  39. extern std::vector<PrimitivePtr> nonlinear_ops;
  40. const std::vector<PrimitivePtr> &GetMsNonlinearOps();
  41. using VmEvalFunc = std::function<BaseRef(const VectorRef &)>;
  42. using VmEvalFuncPtr = std::shared_ptr<std::function<BaseRef(const VectorRef &)>>;
  43. class CompileGraph {
  44. public:
  45. explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  46. ~CompileGraph() = default;
  47. InstSet Run(const FuncGraphPtr &func_graph);
  48. InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph);
  49. bool IsCut(const AnfNodePtr &node);
  50. void Push(const AnfNodePtr &node);
  51. void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
  52. void Ret(int nargs);
  53. void GenMultiGraphsRun(const FuncGraphPtr &graph);
  54. int Ref(const AnfNodePtr &node);
  55. VectorRef SplitNodes(const FuncGraphPtr &func_graph);
  56. void set_height(int h) {
  57. height_ = h;
  58. if (height_ > max_height_) {
  59. max_height_ = height_;
  60. }
  61. }
  62. void Reset() {
  63. height_ = 0;
  64. max_height_ = 0;
  65. slots_.clear();
  66. inst_.clear();
  67. }
  68. private:
  69. void PushParameters(const FuncGraphPtr &func_graph);
  70. bool SplitGraph(const FuncGraphPtr &func_graph);
  71. int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list);
  72. int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
  73. int AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
  74. void AddSinkSwitch(const CNodePtr &node);
  75. void AddPadStack(int param_height);
  76. void AddTailCall(const AnfNodePtr &fn, size_t size);
  77. void AddPartial(const CNodePtr &node);
  78. void AddMakeTuple(const CNodePtr &node);
  79. void AddSwitch(const CNodePtr &node);
  80. void AddReturn(const CNodePtr &node);
  81. void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
  82. void AddInput(const AnfNodePtr &node);
  83. void AddExternal(const LinConvertResult &result);
  84. void AddInst(const Instruction &inst, const int &arg);
  85. void AddInst(const Instruction &inst, const ValuePtr &arg);
  86. void AddInst(const Instruction &inst, const VectorRef &args);
  87. BackendPtr backend_;
  88. LinkFuncType lin_convert_;
  89. bool is_gevm_convert_;
  90. int height_{0};
  91. int max_height_{0};
  92. std::vector<PrimitivePtr> cut_list_;
  93. std::unordered_map<AnfNodePtr, int> slots_;
  94. InstSet inst_;
  95. };
  96. using CompileGraphPtr = std::shared_ptr<CompileGraph>;
  97. // CompileGraphs is used to Convert a graph cluster into instruction lists.
  98. class CompileGraphs {
  99. public:
  100. explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  101. ~CompileGraphs() = default;
  102. void Reset() {
  103. insts_.clear();
  104. mapping_.clear();
  105. }
  106. void Compile(const FuncGraphPtr &func_graph);
  107. FinalVMPtr Link(const FuncGraphPtr &func_graph);
  108. FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
  109. private:
  110. InstSet insts_;
  111. std::unordered_map<FuncGraphPtr, int> mapping_;
  112. CompileGraphPtr transform_;
  113. BackendPtr backend_;
  114. };
  115. BackendPtr CreateBackend();
  116. } // namespace compile
  117. } // namespace mindspore
  118. #endif // MINDSPORE_CCSRC_VM_TRANSFORM_H_