You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vm.h 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_VM_VM_H_
  19. #define MINDSPORE_CCSRC_VM_VM_H_
  #include <deque>
  #include <functional>
  #include <iosfwd>
  #include <map>
  #include <memory>
  #include <stack>
  #include <string>
  #include <tuple>
  #include <unordered_map>
  #include <utility>
  #include <vector>

  #include "ir/anf.h"
  #include "utils/base_ref_extends.h"
  31. namespace mindspore {
  32. namespace compile {
  // Backend is only forward-declared here; FinalVM receives and stores one
  // through this shared-ownership handle.
  class Backend;
  using BackendPtr = std::shared_ptr<Backend>;

  // Opcodes understood by FinalVM.
  //
  // Every enumerator carries its value explicitly. The inst_str table lists
  // the opcode names in exactly this order (presumably it is indexed by these
  // values), so any change here must be mirrored there.
  // NOTE(review): kGraph has no handler in FinalVM's dispatch table.
  enum Instruction {
    kCall = 0,
    kTailCall = 1,
    kReturn = 2,
    kPartial = 3,
    kSwitch = 4,
    kSwitchReturn = 5,
    kTuple = 6,
    kInput = 7,
    kExternal = 8,
    kPush = 9,
    kPrim = 10,
    kGraph = 11,
    kPadStack = 12,
    kSwitchLayer = 13
  };
  51. using InstType = std::pair<Instruction, VectorRef>;
  52. using InstSet = std::vector<InstType>;
  53. using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>;
  // Printable opcode names. Entry i names the Instruction enumerator whose
  // value is i, so this list must stay in the same order as the enum.
  const std::vector<std::string> inst_str{
    "call",           // kCall
    "tail_call",      // kTailCall
    "return",         // kReturn
    "partial",        // kPartial
    "switch",         // kSwitch
    "switch_return",  // kSwitchReturn
    "tuple",          // kTuple
    "input",          // kInput
    "external",       // kExternal
    "push",           // kPush
    "primitive",      // kPrim
    "graph",          // kGraph
    "pad_stack",      // kPadStack
    "switch_layer"    // kSwitchLayer
  };
  57. class StructPartial : public Base {
  58. public:
  59. // Initialize StructPartial.
  60. StructPartial(int fn, const VectorRef &args, const FuncGraphPtr &fg = nullptr);
  61. virtual ~StructPartial() = default;
  62. MS_DECLARE_PARENT(StructPartial, Base)
  63. int fn_;
  64. VectorRef args_;
  65. FuncGraphPtr fg_;
  66. };
  67. std::ostream &operator<<(std::ostream &os, const StructPartial &other);
  68. bool operator==(const StructPartial &lhs, const StructPartial &rhs);
  69. class StructSimuSwitch : public Base {
  70. public:
  71. StructSimuSwitch(const BaseRef &fn, const BaseRef &value);
  72. virtual ~StructSimuSwitch() = default;
  73. MS_DECLARE_PARENT(StructSimuSwitch, Base)
  74. BaseRef fn_;
  75. BaseRef value_;
  76. };
  77. std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other);
  78. bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs);
  79. class FinalVM {
  80. public:
  81. // Create a VM with the specified instructions and backend.
  82. explicit FinalVM(const InstSet &insts, const BackendPtr &backend);
  83. virtual ~FinalVM() = default;
  84. BaseRef Eval(const VectorRef &args);
  85. void InstCall(const VectorRef &args);
  86. void InstTailCall(const VectorRef &args);
  87. void InstReturn(const VectorRef &args);
  88. void InstPartial(const VectorRef &args);
  89. void InstSimuPartial(const VectorRef &args);
  90. void InstRealPartial(const VectorRef &args);
  91. void InstSwitch(const VectorRef &args);
  92. void InstSimuSwitch(const VectorRef &args);
  93. void InstRealSwitch(const VectorRef &args);
  94. void InstTuple(const VectorRef &args);
  95. void InstPush(const VectorRef &args);
  96. void InstInput(const VectorRef &args);
  97. void InstPadStack(const VectorRef &args);
  98. void InstExternal(const VectorRef &args);
  99. void InstPushPrim(const VectorRef &args);
  100. void InstSwitchReturn(const VectorRef &args);
  101. void InstSwitchLayer(const VectorRef &args);
  102. void set_insts(const InstSet &value) { insts_ = value; }
  103. BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg);
  104. protected:
  105. BaseRef Ref(int i);
  106. void Push(const BaseRef &v);
  107. void Pop(int n = 1);
  108. void MoveStack(int nitems, int height);
  109. void Pushp();
  110. void Popp();
  111. void Pushsp();
  112. void Popsp();
  113. void PushStatus(bool is_switch_call);
  114. bool PopStatus();
  115. void DoJmp(const BaseRef &jmp);
  116. void SyncData(const py::object &args);
  117. void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
  118. BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
  119. private:
  120. InstSet insts_;
  121. std::deque<BaseRef> insts_stack_;
  122. std::stack<int> retp_;
  123. std::stack<int> retsp_;
  124. std::stack<bool> ret_status_;
  125. int pc_;
  126. int sp_;
  127. std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
  128. std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_;
  129. BackendPtr backend_;
  130. const InstFunctionMap inst_function_map = {
  131. {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
  132. {Instruction::kTailCall, [this](const VectorRef &args) { InstTailCall(args); }},
  133. {Instruction::kReturn, [this](const VectorRef &args) { InstReturn(args); }},
  134. {Instruction::kPartial, [this](const VectorRef &args) { InstPartial(args); }},
  135. {Instruction::kSwitch, [this](const VectorRef &args) { InstSwitch(args); }},
  136. {Instruction::kTuple, [this](const VectorRef &args) { InstTuple(args); }},
  137. {Instruction::kPush, [this](const VectorRef &args) { InstPush(args); }},
  138. {Instruction::kInput, [this](const VectorRef &args) { InstInput(args); }},
  139. {Instruction::kPadStack, [this](const VectorRef &args) { InstPadStack(args); }},
  140. {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }},
  141. {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
  142. {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
  143. {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}};
  144. };
  145. using FinalVMPtr = std::shared_ptr<FinalVM>;
  146. } // namespace compile
  147. } // namespace mindspore
  148. #endif // MINDSPORE_CCSRC_VM_VM_H_