You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

transform.h 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_VM_TRANSFORM_H_
  19. #define MINDSPORE_CCSRC_VM_TRANSFORM_H_
  20. #include <string>
  21. #include <memory>
  22. #include <functional>
  23. #include <utility>
  24. #include <unordered_map>
  25. #include <vector>
  26. #include "vm/vm.h"
  27. #include "ir/anf.h"
  28. #include "frontend/operator/ops.h"
  29. #include "vm/segment_runner.h"
  30. #include "vm/backend.h"
  31. #include "vm/graph_partition.h"
  // The mindspore namespace is the top-level namespace of the MindSpore project.
  // Every other namespace in the ME project should be nested inside it.
  34. namespace mindspore {
  // Backend VM name constants. Only declared here (`extern`, unknown bound);
  // the character data is defined in another translation unit.
  // NOTE(review): by their names these identify the "Ms" and "Ge" VMs — confirm
  // against the definitions.
  extern const char kMsVm[];
  extern const char kGeVm[];
  // compile namespace
  // A sub-namespace of mindspore in ME that holds compile-related definitions.
  39. namespace compile {
  40. extern std::vector<PrimitivePtr> nonlinear_ops;
  41. const std::vector<PrimitivePtr> &GetMsNonlinearOps();
  42. using VmEvalFunc = std::function<BaseRef(const VectorRef &)>;
  43. using VmEvalFuncPtr = std::shared_ptr<std::function<BaseRef(const VectorRef &)>>;
  44. class CompileGraph {
  45. public:
  46. explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  47. ~CompileGraph() = default;
  48. InstSet Run(const FuncGraphPtr &func_graph);
  49. bool IsCut(const AnfNodePtr &node);
  50. void Push(const AnfNodePtr &node);
  51. void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
  52. void Ret(int64_t nargs);
  53. int64_t Ref(const AnfNodePtr &node);
  54. void set_height(int64_t h) {
  55. height_ = h;
  56. if (height_ > max_height_) {
  57. max_height_ = height_;
  58. }
  59. }
  60. void Reset() {
  61. height_ = 0;
  62. max_height_ = 0;
  63. slots_.clear();
  64. inst_.clear();
  65. }
  66. private:
  67. void PushParameters(const FuncGraphPtr &func_graph);
  68. bool Compile(const FuncGraphPtr &func_graph);
  69. int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
  70. int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
  71. int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
  72. void AddPadStack(int64_t param_height);
  73. void AddTailCall(const AnfNodePtr &fn, size_t size);
  74. void AddPartial(const CNodePtr &node);
  75. void AddMakeTuple(const CNodePtr &node);
  76. void AddSwitch(const CNodePtr &node);
  77. void AddSwitchLayer(const CNodePtr &node);
  78. void AddReturn(const CNodePtr &node);
  79. void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
  80. void AddInput(const AnfNodePtr &node);
  81. void AddExternal(const LinConvertResult &result);
  82. void AddInst(const Instruction &inst, const int64_t &arg);
  83. void AddInst(const Instruction &inst, const ValuePtr &arg);
  84. void AddInst(const Instruction &inst, const VectorRef &args);
  85. BackendPtr backend_;
  86. GraphPartitionPtr graph_partition_;
  87. LinkFuncType lin_convert_;
  88. int64_t height_{0};
  89. int64_t max_height_{0};
  90. std::unordered_map<AnfNodePtr, int64_t> slots_;
  91. InstSet inst_;
  92. };
  93. using CompileGraphPtr = std::shared_ptr<CompileGraph>;
  94. // CompileGraphs is used to Convert a graph cluster into instruction lists.
  95. class CompileGraphs {
  96. public:
  97. explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  98. ~CompileGraphs() = default;
  99. void Reset() {
  100. insts_.clear();
  101. mapping_.clear();
  102. }
  103. void Compile(const FuncGraphPtr &func_graph);
  104. FinalVMPtr Link(const FuncGraphPtr &func_graph);
  105. FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
  106. private:
  107. InstSet insts_;
  108. std::unordered_map<FuncGraphPtr, int64_t> mapping_;
  109. CompileGraphPtr transform_;
  110. BackendPtr backend_;
  111. };
  112. BackendPtr CreateBackend();
  113. } // namespace compile
  114. } // namespace mindspore
  115. #endif // MINDSPORE_CCSRC_VM_TRANSFORM_H_