You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

vm.cc 15 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/vm.h"
  19. #include <algorithm>
  20. #include "vm/vmimpl.h"
  21. #include "vm/backend.h"
  22. #include "pipeline/jit/parse/data_converter.h"
  23. #include "pybind_api/ir/base_ref_py.h"
  24. #include "pybind_api/ir/primitive_py.h"
  25. namespace mindspore {
  26. namespace compile {
  27. // Initialize StructPartial.
  28. // Arguments:
  29. // fn_: Callable function.
  30. // args_: Sequence of function args.
  31. // fg_: Graph of function.
  32. StructPartial::StructPartial(int64_t fn, const VectorRef &args, const FuncGraphPtr &fg)
  33. : fn_(fn), args_(args), fg_(fg) {}
  34. std::ostream &operator<<(std::ostream &os, const StructPartial &other) {
  35. os << "Partial(" << other.fn_ << ", " << other.args_.ToString() << ")";
  36. return os;
  37. }
  38. bool operator==(const StructPartial &lhs, const StructPartial &rhs) {
  39. return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_ && lhs.fg_ == rhs.fg_);
  40. }
  41. StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {}
  42. std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other) {
  43. os << "SimulSwitch(" << other.fn_.ToString() << ", " << other.value_.ToString() << ")";
  44. return os;
  45. }
  46. bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs) {
  47. return (lhs.fn_ == rhs.fn_ && lhs.value_ == rhs.value_);
  48. }
  49. std::ostream &operator<<(std::ostream &os, const SwitchCondStatus &other) {
  50. os << "SwitchCondStatus(" << static_cast<int64_t>(other) << ")";
  51. return os;
  52. }
  53. // Follow the specified instructions to create a VM.
  54. // Arguments:
  55. // insts_: std::vector<std::map<std::string, VectorRef>>
  56. // insts_stack_: The value stack.
  57. // retp_: The call stack.
  58. // pc_: program counter (next instruction)
  59. // sp_: stack pointer (for the value stack)
  60. FinalVM::FinalVM(const InstSet &insts, const BackendPtr &backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) {
  61. MS_LOG(DEBUG) << "InstSet size:" << insts_.size();
  62. insts_stack_.emplace_back(BaseRef());
  63. retp_.push(-1);
  64. }
  65. void FinalVM::Push(const BaseRef &v) {
  66. MS_LOG(DEBUG) << "Push " << v.ToString() << " sp_:" << sp_;
  67. insts_stack_[IntToSize(sp_++)] = v;
  68. }
  69. void FinalVM::Pop(int64_t n) {
  70. if (n > sp_) {
  71. MS_LOG(EXCEPTION) << "Invalid value of n " << n << ", it should be not more than " << (sp_ - 1);
  72. }
  73. for (int64_t i = 0; i < n; i++) {
  74. insts_stack_[IntToSize(sp_ - i - 1)] = BaseRef();
  75. }
  76. sp_ -= n;
  77. }
  78. void FinalVM::MoveStack(int64_t nitems, int64_t height) {
  79. if (nitems > height || height > sp_) {
  80. MS_LOG(EXCEPTION) << "MoveStack arg error: nitems=" << nitems << " height=" << height << " sp=" << sp_;
  81. }
  82. int64_t n = height - nitems;
  83. int64_t src = sp_ - height;
  84. int64_t dst = sp_ - nitems;
  85. for (int64_t i = 0; i < nitems; i++) {
  86. insts_stack_[IntToSize(src + i)] = insts_stack_[IntToSize(dst + i)];
  87. }
  88. Pop(n);
  89. }
  90. BaseRef FinalVM::Ref(int64_t i) {
  91. MS_LOG(DEBUG) << "Ref i:" << i << " sp_:" << sp_;
  92. size_t sp_next = LongToSize(sp_ + i);
  93. if (sp_next < insts_stack_.size()) {
  94. if (utils::isa<PyObjectRef>(insts_stack_[sp_next])) {
  95. py::object value = utils::cast<PyObjectRef>(insts_stack_[sp_next]).object_;
  96. MS_LOG(DEBUG) << "VM ref python:" << py::str(value);
  97. return parse::data_converter::PyDataToValue(value);
  98. }
  99. MS_LOG(DEBUG) << "Ref not python :" << insts_stack_[sp_next].ToString();
  100. return insts_stack_[sp_next];
  101. }
  102. MS_LOG(EXCEPTION) << "IndexError: index(" << sp_next << ") out of range [0, " << insts_stack_.size() << ").";
  103. }
  104. void FinalVM::Pushp() { retp_.push(pc_); }
  105. void FinalVM::Popp() {
  106. if (retp_.empty()) {
  107. MS_LOG(EXCEPTION) << "Stack retp_ is empty";
  108. }
  109. pc_ = retp_.top();
  110. MS_LOG(DEBUG) << "Pop pc:" << pc_ << ", sp:" << sp_;
  111. retp_.pop();
  112. }
  113. void FinalVM::Pushsp() { retsp_.push(sp_); }
  114. void FinalVM::Popsp() {
  115. int64_t sp = retsp_.top();
  116. MS_LOG(DEBUG) << "Current sp:" << sp_ << ", before sp:" << sp << ", " << sp_ - sp;
  117. if (sp_ >= sp) {
  118. Pop(sp_ - sp + 1);
  119. retsp_.pop();
  120. } else {
  121. MS_LOG(EXCEPTION) << "Stack point sp_:" << sp << " must biger than sp:" << sp_;
  122. }
  123. }
  124. void FinalVM::DoJmp(const BaseRef &jmp_orig) {
  125. MS_LOG(DEBUG) << "Start";
  126. BaseRef jmp = jmp_orig;
  127. if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
  128. MS_LOG(DEBUG) << "Start jump StructPartial";
  129. auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
  130. auto args = new_jmp->args_;
  131. InstPadStack(VectorRef(std::vector<BaseRef>{static_cast<int64_t>(args.size())}));
  132. auto iter = args.rbegin();
  133. for (; iter != args.rend(); ++iter) {
  134. Push(*iter);
  135. }
  136. pc_ = new_jmp->fn_;
  137. return;
  138. }
  139. if (!utils::isa<int64_t>(jmp)) {
  140. MS_LOG(EXCEPTION) << "Jmp inst should be a int64_t";
  141. }
  142. pc_ = utils::cast<int64_t>(jmp);
  143. MS_LOG(DEBUG) << "End do jump pc_:" << pc_;
  144. }
  145. BaseRef FinalVM::Eval(const VectorRef &args) {
  146. MS_LOG(DEBUG) << "Start: " << args.size();
  147. insts_stack_.clear();
  148. insts_stack_.resize(args.size());
  149. std::stack<int64_t>().swap(retp_);
  150. retp_.push(-1);
  151. pc_ = 0;
  152. sp_ = 0;
  153. auto riter = args.rbegin();
  154. for (; riter != args.rend(); ++riter) {
  155. if (utils::isa<PyObjectRef>(*riter)) {
  156. PyObjectRef py_ref = utils::cast<PyObjectRef>(*riter);
  157. py::object value = py_ref.object_;
  158. if (py::isinstance<py::bool_>(value)) {
  159. auto a = py::cast<bool>(value);
  160. Push(static_cast<int64_t>(a));
  161. continue;
  162. }
  163. }
  164. Push(*riter);
  165. }
  166. while (pc_ >= 0) {
  167. auto inst = insts_[IntToSize(pc_)];
  168. MS_LOG(DEBUG) << "Loop " << insts_.size() << ", pc:" << pc_ << ", inst:" << inst_str[inst.first];
  169. ++pc_;
  170. auto iter = inst_function_map.find(inst.first);
  171. if (iter != inst_function_map.end()) {
  172. iter->second(inst.second);
  173. } else {
  174. MS_LOG(EXCEPTION) << "Unknown instruction {" << inst_str[inst.first] << "}";
  175. }
  176. }
  177. MS_LOG(DEBUG) << "End";
  178. return insts_stack_[0];
  179. }
  180. void FinalVM::InstCall(const VectorRef &args) {
  181. MS_LOG(DEBUG) << "Start";
  182. const size_t args_size = 1;
  183. if (args.size() != args_size) {
  184. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
  185. << ".";
  186. return;
  187. }
  188. int64_t jmp = utils::cast<int64_t>(args[0]);
  189. MS_LOG(DEBUG) << "Call pushp:" << pc_ << ", jmp:" << jmp << ", sp:" << sp_;
  190. Pushp();
  191. DoJmp(Ref(jmp));
  192. MS_LOG(DEBUG) << "Instcall end sp :" << sp_;
  193. }
  194. void FinalVM::InstTailCall(const VectorRef &args) {
  195. MS_LOG(DEBUG) << "Start";
  196. const size_t args_size = 3;
  197. if (args.size() != args_size) {
  198. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
  199. << ".";
  200. return;
  201. }
  202. const size_t jmp_index = 0;
  203. const size_t height_index = 1;
  204. const size_t nargs_index = 2;
  205. int64_t jmp = utils::cast<int64_t>(args[jmp_index]);
  206. int64_t height = utils::cast<int64_t>(args[height_index]);
  207. int64_t nargs = utils::cast<int64_t>(args[nargs_index]);
  208. auto new_jmp = Ref(jmp);
  209. MoveStack(nargs, height);
  210. MS_LOG(DEBUG) << "TailCall pushp:" << pc_ << ", jmp:" << jmp;
  211. DoJmp(new_jmp);
  212. MS_LOG(DEBUG) << "End";
  213. }
  214. void FinalVM::InstSwitchReturn(const VectorRef &args) {
  215. MS_LOG(DEBUG) << "Start";
  216. if (args.size() != 1) {
  217. MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
  218. return;
  219. }
  220. Pop(1);
  221. Popsp();
  222. }
  223. void FinalVM::InstReturn(const VectorRef &args) {
  224. MS_LOG(DEBUG) << "Start";
  225. const size_t args_size = 2;
  226. if (args.size() != args_size) {
  227. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
  228. << ".";
  229. return;
  230. }
  231. int64_t rpos = utils::cast<int64_t>(args[0]);
  232. int64_t height = utils::cast<int64_t>(args[1]);
  233. auto rv = Ref(rpos);
  234. Pop(height);
  235. Push(rv);
  236. Popp();
  237. MS_LOG(DEBUG) << "End";
  238. }
  239. void FinalVM::InstRealPartial(const VectorRef &args) {
  240. const size_t args_size = 1;
  241. if (args.size() < args_size) {
  242. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
  243. << args.size() << ".";
  244. return;
  245. }
  246. int64_t fn_ = utils::cast<int64_t>(args[0]);
  247. auto fn = utils::cast<int64_t>(Ref(fn_));
  248. MS_LOG(DEBUG) << "Partial argssize:" << args.size();
  249. std::vector<BaseRef> outs(args.size() - 1);
  250. (void)std::transform(args.begin() + 1, args.end(), outs.begin(),
  251. [&, this](const BaseRef &a) { return Ref(utils::cast<int64_t>(a)); });
  252. Push(std::make_shared<StructPartial>(fn, VectorRef(outs)));
  253. }
  254. void FinalVM::InstPartial(const VectorRef &args) {
  255. MS_LOG(DEBUG) << "Start";
  256. InstRealPartial(args);
  257. MS_LOG(DEBUG) << "End";
  258. }
  259. void FinalVM::InstRealSwitch(const VectorRef &args) {
  260. const size_t args_size = 3;
  261. if (args.size() != args_size) {
  262. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
  263. << ".";
  264. return;
  265. }
  266. const size_t cond_index = 0;
  267. const size_t vtrue_index = 1;
  268. const size_t vfalse_index = 2;
  269. int64_t cond = utils::cast<int64_t>(args[cond_index]);
  270. int64_t vtrue = utils::cast<int64_t>(args[vtrue_index]);
  271. int64_t vfalse = utils::cast<int64_t>(args[vfalse_index]);
  272. BaseRef c = Ref(cond);
  273. MS_LOG(DEBUG) << vtrue << " false:" << vfalse << " InstSwitch: " << c.ToString();
  274. bool bool_value = false;
  275. MS_EXCEPTION_IF_NULL(backend_);
  276. if (backend_->GetCond(c, &bool_value)) {
  277. MS_LOG(DEBUG) << "Cond:" << bool_value;
  278. if (bool_value) {
  279. Push(Ref(vtrue));
  280. } else {
  281. Push(Ref(vfalse));
  282. }
  283. } else {
  284. MS_LOG(EXCEPTION) << "Not supported type to be casted to bool";
  285. }
  286. }
  287. void FinalVM::InstSwitch(const VectorRef &args) {
  288. MS_LOG(DEBUG) << "Start";
  289. InstRealSwitch(args);
  290. MS_LOG(DEBUG) << "End";
  291. }
  292. void FinalVM::InstSwitchLayer(const VectorRef &args) {
  293. MS_LOG(DEBUG) << "Start";
  294. const size_t args_size = 2;
  295. if (args.size() != args_size) {
  296. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
  297. << ".";
  298. return;
  299. }
  300. int64_t idx = utils::cast<int64_t>(args[0]);
  301. VectorRef branches = utils::cast<VectorRef>(Ref(utils::cast<int64_t>(args[1])));
  302. int64_t size = static_cast<int64_t>(branches.size());
  303. BaseRef index = Ref(idx);
  304. int64_t idx_value = 0;
  305. MS_EXCEPTION_IF_NULL(backend_);
  306. if (!backend_->GetIndex(index, &idx_value)) {
  307. MS_LOG(EXCEPTION) << "Not supported type to be casted to int64_t.";
  308. }
  309. auto ori_value = idx_value;
  310. if (idx_value < 0) {
  311. // Add support negative index range [-size, -1].
  312. idx_value += size;
  313. }
  314. if (idx_value < 0 || idx_value >= size) {
  315. MS_EXCEPTION(IndexError) << __FUNCTION__ << " given index " << ori_value
  316. << " out of range. Please make sure the value "
  317. << "of index in [" << -size << ", " << size << "), and the type is int32.";
  318. }
  319. Push(branches[idx_value]);
  320. MS_LOG(DEBUG) << "End";
  321. }
  322. void FinalVM::InstTuple(const VectorRef &args) {
  323. MS_LOG(DEBUG) << "Start";
  324. VectorRef tuple;
  325. auto iter = args.begin();
  326. for (; iter != args.end(); ++iter) {
  327. auto a = utils::cast<int64_t>(*iter);
  328. tuple.push_back(Ref(a));
  329. }
  330. Push(tuple);
  331. MS_LOG(DEBUG) << "End";
  332. }
  333. void FinalVM::InstPush(const VectorRef &args) {
  334. MS_LOG(DEBUG) << "Start";
  335. const size_t args_size = 1;
  336. if (args.size() != args_size) {
  337. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
  338. << ".";
  339. return;
  340. }
  341. auto v = args[0];
  342. Push(v);
  343. MS_LOG(DEBUG) << "End";
  344. }
  345. void FinalVM::InstInput(const VectorRef &args) {
  346. MS_LOG(DEBUG) << "Start";
  347. const size_t args_size = 1;
  348. if (args.size() != args_size) {
  349. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
  350. << ".";
  351. return;
  352. }
  353. int64_t rpos = utils::cast<int64_t>(args[0]);
  354. Push(Ref(rpos));
  355. MS_LOG(DEBUG) << "End";
  356. }
  357. void FinalVM::InstPadStack(const VectorRef &args) {
  358. MS_LOG(DEBUG) << "Start";
  359. const size_t args_size = 1;
  360. if (args.size() != args_size) {
  361. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
  362. << ".";
  363. return;
  364. }
  365. int64_t sz = utils::cast<int64_t>(args[0]);
  366. MS_LOG(DEBUG) << insts_stack_.size() << " need padstack " << sz << " sp_ " << sp_;
  367. size_t stack_size = insts_stack_.size();
  368. int64_t need = sz - (static_cast<int64_t>(stack_size) - sp_);
  369. if (need > 0) {
  370. MS_LOG(DEBUG) << "InstPadStack resize: size:" << insts_stack_.size() << " need pad:" << need;
  371. insts_stack_.resize(stack_size + IntToSize(need));
  372. }
  373. MS_LOG(DEBUG) << "End";
  374. }
  375. void FinalVM::InstExternal(const VectorRef &args) {
  376. MS_LOG(DEBUG) << "Start:" << args.size();
  377. if (args.empty()) {
  378. MS_LOG(EXCEPTION) << "Args is empty!";
  379. }
  380. VectorRef tuple;
  381. RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]);
  382. compile::RunFuncPtr fn = run_ref.func_;
  383. const size_t arg_start_index = 2;
  384. for (size_t i = arg_start_index; i < args.size(); ++i) {
  385. auto index = utils::cast<int64_t>(args[i]);
  386. tuple.push_back(Ref(index));
  387. }
  388. if (!fn) {
  389. MS_LOG(EXCEPTION) << "Function not callable";
  390. }
  391. auto outs = (*fn)(tuple);
  392. MS_LOG(DEBUG) << "The 'fn' out size:" << outs.size();
  393. for (auto &o : outs) {
  394. MS_LOG(DEBUG) << "InstExternal value:" << o.ToString();
  395. Push(o);
  396. }
  397. MS_LOG(DEBUG) << "End";
  398. }
  399. void FinalVM::InstPushPrim(const VectorRef &args) {
  400. MS_LOG(DEBUG) << "Start: " << args.size();
  401. const size_t args_size = 2;
  402. if (args.size() < args_size) {
  403. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
  404. << args.size() << ".";
  405. return;
  406. }
  407. auto prim = utils::cast<PrimitivePtr>(args[0]);
  408. VectorRef tuple;
  409. for (size_t i = 1; i < args.size(); ++i) {
  410. auto index = utils::cast<int64_t>(args[i]);
  411. tuple.push_back(Ref(index));
  412. }
  413. if (prim->name() == kBpropCutOpName) {
  414. auto py_prim = prim->cast<PrimitivePyPtr>();
  415. MS_EXCEPTION_IF_NULL(py_prim);
  416. auto outs = py_prim->RunHookFunction(tuple);
  417. Push(outs);
  418. } else {
  419. auto outs = RunOperation(prim, tuple);
  420. Push(outs);
  421. }
  422. MS_LOG(DEBUG) << "End";
  423. }
  424. } // namespace compile
  425. } // namespace mindspore