You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

transform.h 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_VM_TRANSFORM_H_
  19. #define MINDSPORE_CCSRC_VM_TRANSFORM_H_
  20. #include <string>
  21. #include <memory>
  22. #include <functional>
  23. #include <utility>
  24. #include <unordered_map>
  25. #include <vector>
  26. #include "vm/vm.h"
  27. #include "ir/anf.h"
  28. #include "frontend/operator/ops.h"
  29. #include "vm/segment_runner.h"
  30. #include "vm/backend.h"
  31. #include "vm/graph_partition.h"
  // The mindspore namespace is the top-level namespace of the MindSpore project.
  // Other namespaces should be sub-namespaces of the mindspore namespace in the ME project.
  34. namespace mindspore {
  // Backend name constants. They are only declared in this header (extern, array of
  // unknown bound); the definitions live in another translation unit.
  // NOTE(review): presumably kMsVm names the MindSpore VM and kGeVm the GE VM —
  // confirm against the definitions.
  extern const char kMsVm[];
  extern const char kGeVm[];
  // compile namespace
  // A sub-namespace in ME for compilation-related definitions.
  39. namespace compile {
  40. extern std::vector<PrimitivePtr> nonlinear_ops;
  41. const std::vector<PrimitivePtr> &GetMsNonlinearOps();
  42. FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph);
  43. using VmEvalFunc = std::function<BaseRef(const VectorRef &)>;
  44. using VmEvalFuncPtr = std::shared_ptr<std::function<BaseRef(const VectorRef &)>>;
  45. class CompileGraph {
  46. public:
  47. explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  48. ~CompileGraph() = default;
  49. InstSet Run(const FuncGraphPtr &func_graph);
  50. bool IsCut(const AnfNodePtr &node);
  51. void Push(const AnfNodePtr &node);
  52. void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
  53. void Ret(int64_t nargs);
  54. int64_t Ref(const AnfNodePtr &node);
  55. void set_height(int64_t h) {
  56. height_ = h;
  57. if (height_ > max_height_) {
  58. max_height_ = height_;
  59. }
  60. }
  61. void Reset() {
  62. height_ = 0;
  63. max_height_ = 0;
  64. slots_.clear();
  65. inst_.clear();
  66. }
  67. private:
  68. void PushParameters(const FuncGraphPtr &func_graph);
  69. bool Compile(const FuncGraphPtr &func_graph);
  70. int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
  71. int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
  72. int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
  73. void AddPadStack(int64_t param_height);
  74. void AddTailCall(const AnfNodePtr &fn, size_t size);
  75. void AddPartial(const CNodePtr &node);
  76. void AddMakeTuple(const CNodePtr &node);
  77. void AddSwitch(const CNodePtr &node);
  78. void AddSwitchLayer(const CNodePtr &node);
  79. void AddReturn(const CNodePtr &node);
  80. void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
  81. void AddInput(const AnfNodePtr &node);
  82. void AddExternal(const LinConvertResult &result);
  83. void AddInst(const Instruction &inst, const int64_t &arg);
  84. void AddInst(const Instruction &inst, const ValuePtr &arg);
  85. void AddInst(const Instruction &inst, const VectorRef &args);
  86. BackendPtr backend_;
  87. GraphPartitionPtr graph_partition_;
  88. LinkFuncType lin_convert_;
  89. int64_t height_{0};
  90. int64_t max_height_{0};
  91. std::unordered_map<AnfNodePtr, int64_t> slots_;
  92. InstSet inst_;
  93. };
  94. using CompileGraphPtr = std::shared_ptr<CompileGraph>;
  95. // CompileGraphs is used to Convert a graph cluster into instruction lists.
  96. class CompileGraphs {
  97. public:
  98. explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  99. ~CompileGraphs() = default;
  100. void Reset() {
  101. insts_.clear();
  102. mapping_.clear();
  103. }
  104. void Compile(const FuncGraphPtr &func_graph);
  105. FinalVMPtr Link(const FuncGraphPtr &func_graph);
  106. FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
  107. private:
  108. InstSet insts_;
  109. std::unordered_map<FuncGraphPtr, int64_t> mapping_;
  110. CompileGraphPtr transform_;
  111. BackendPtr backend_;
  112. };
  113. // Judge whether to use mindRT. GPU and CPU use mindRT currently, and other hardwares will use it in the future.
  114. bool IsMindRTUsed();
  115. BackendPtr CreateBackend();
  116. } // namespace compile
  117. } // namespace mindspore
  118. #endif // MINDSPORE_CCSRC_VM_TRANSFORM_H_