You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transform.h 4.7 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_VM_TRANSFORM_H_
  19. #define MINDSPORE_CCSRC_VM_TRANSFORM_H_
  20. #include <string>
  21. #include <memory>
  22. #include <functional>
  23. #include <utility>
  24. #include <vector>
  25. #include "utils/hash_map.h"
  26. #include "backend/graph_compiler/vm.h"
  27. #include "ir/anf.h"
  28. #include "frontend/operator/ops.h"
  29. #include "backend/graph_compiler/segment_runner.h"
  30. #include "backend/graph_compiler/backend.h"
  31. #include "backend/graph_compiler/graph_partition.h"
// The mindspore namespace is the top-level namespace of the MindSpore project.
// All other namespaces in the ME project should be nested inside it.
  34. namespace mindspore {
  35. extern const char kMsVm[];
  36. extern const char kGeVm[];
// compile namespace
// A sub-namespace of ME that holds compilation-related definitions.
  39. namespace compile {
  40. extern std::vector<PrimitivePtr> nonlinear_ops;
  41. extern std::vector<PrimitivePtr> control_ops;
  42. const std::vector<PrimitivePtr> &GetMsNonlinearOps();
  43. FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph);
  44. using VmEvalFunc = std::function<BaseRef(const VectorRef &)>;
  45. using VmEvalFuncPtr = std::shared_ptr<std::function<BaseRef(const VectorRef &)>>;
  46. class CompileGraph {
  47. public:
  48. explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  49. virtual ~CompileGraph() = default;
  50. InstSet Run(const FuncGraphPtr &func_graph);
  51. bool IsCut(const AnfNodePtr &node);
  52. void Push(const AnfNodePtr &node);
  53. void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
  54. void Ret(int64_t nargs);
  55. virtual int64_t Ref(const AnfNodePtr &node);
  56. void set_height(int64_t h) {
  57. height_ = h;
  58. if (height_ > max_height_) {
  59. max_height_ = height_;
  60. }
  61. }
  62. void Reset() {
  63. height_ = 0;
  64. max_height_ = 0;
  65. slots_.clear();
  66. inst_.clear();
  67. }
  68. protected:
  69. virtual void PushParameters(const FuncGraphPtr &func_graph);
  70. bool Compile(const FuncGraphPtr &func_graph);
  71. int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
  72. int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
  73. virtual int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
  74. void AddPadStack(int64_t param_height);
  75. void AddTailCall(const AnfNodePtr &fn, size_t size);
  76. virtual void AddPartial(const CNodePtr &node);
  77. void AddMakeTuple(const CNodePtr &node);
  78. void AddSwitch(const CNodePtr &node);
  79. void AddSwitchLayer(const CNodePtr &node);
  80. void AddReturn(const CNodePtr &node);
  81. void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
  82. virtual void AddInput(const AnfNodePtr &node);
  83. virtual void AddExternal(const LinConvertResult &result);
  84. void AddInst(const Instruction &inst, const int64_t &arg);
  85. void AddInst(const Instruction &inst, const ValuePtr &arg);
  86. void AddInst(const Instruction &inst, const VectorRef &args);
  87. BackendPtr backend_;
  88. GraphPartitionPtr graph_partition_;
  89. LinkFuncType lin_convert_;
  90. int64_t height_{0};
  91. int64_t max_height_{0};
  92. mindspore::HashMap<AnfNodePtr, int64_t> slots_;
  93. InstSet inst_;
  94. };
  95. using CompileGraphPtr = std::shared_ptr<CompileGraph>;
  96. // CompileGraphs is used to Convert a graph cluster into instruction lists.
  97. class CompileGraphs {
  98. public:
  99. explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  100. virtual ~CompileGraphs() = default;
  101. void Reset() {
  102. insts_.clear();
  103. mapping_.clear();
  104. }
  105. void Compile(const FuncGraphPtr &func_graph);
  106. FinalVMPtr Link();
  107. FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
  108. protected:
  109. InstSet insts_;
  110. mindspore::HashMap<FuncGraphPtr, int64_t> mapping_;
  111. CompileGraphPtr transform_;
  112. BackendPtr backend_;
  113. };
  114. BackendPtr CreateBackend();
  115. // Set mindRT whether enable. GPU and CPU use mindRT currently, and other hardwares will use it in the future.
  116. void SetMindRTEnable();
  117. } // namespace compile
  118. } // namespace mindspore
  119. #endif // MINDSPORE_CCSRC_VM_TRANSFORM_H_