You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vm.cc 15 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/vm.h"
  19. #include <algorithm>
  20. #include "vm/vmimpl.h"
  21. #include "vm/backend.h"
  22. #include "pipeline/jit/parse/data_converter.h"
  23. #include "pybind_api/ir/base_ref_py.h"
  24. namespace mindspore {
  25. namespace compile {
  26. // Initialize StructPartial.
  27. // Arguments:
  28. // fn_: Callable function.
  29. // args_: Sequence of function args.
  30. // fg_: Graph of function.
  31. StructPartial::StructPartial(int fn, const VectorRef &args, const FuncGraphPtr &fg) : fn_(fn), args_(args), fg_(fg) {}
  32. std::ostream &operator<<(std::ostream &os, const StructPartial &other) {
  33. os << "partial(" << other.fn_ << ", " << other.args_.ToString() << ")";
  34. return os;
  35. }
  36. bool operator==(const StructPartial &lhs, const StructPartial &rhs) {
  37. return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_ && lhs.fg_ == rhs.fg_);
  38. }
  39. StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {}
  40. std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other) {
  41. os << "SimulSwitch(" << other.fn_.ToString() << ", " << other.value_.ToString() << ")";
  42. return os;
  43. }
  44. bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs) {
  45. return (lhs.fn_ == rhs.fn_ && lhs.value_ == rhs.value_);
  46. }
  47. std::ostream &operator<<(std::ostream &os, const SwitchCondStatus &other) {
  48. os << "SwitchCondStatus(" << static_cast<int>(other) << ")";
  49. return os;
  50. }
  51. // Follow the specified instructions to create a VM.
  52. // Arguments:
  53. // insts_: std::vector<std::map<std::string, VectorRef>>
  54. // insts_stack_: The value stack.
  55. // retp_: The call stack.
  56. // pc_: program counter (next instruction)
  57. // sp_: stack pointer (for the value stack)
  58. FinalVM::FinalVM(const InstSet &insts, const BackendPtr &backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) {
  59. MS_LOG(DEBUG) << "InstSet size:" << insts_.size();
  60. insts_stack_.emplace_back(BaseRef());
  61. retp_.push(-1);
  62. }
  63. void FinalVM::Push(const BaseRef &v) {
  64. MS_LOG(DEBUG) << "Push " << v.ToString() << " sp_:" << sp_;
  65. insts_stack_[IntToSize(sp_++)] = v;
  66. }
  67. void FinalVM::Pop(int n) {
  68. if (n > sp_) {
  69. MS_LOG(EXCEPTION) << "Invalid value of n " << n << ", it should be not more than " << sp_ - 1;
  70. }
  71. for (int i = 0; i < n; i++) {
  72. insts_stack_[IntToSize(sp_ - i - 1)] = BaseRef();
  73. }
  74. sp_ -= n;
  75. }
  76. void FinalVM::MoveStack(int nitems, int height) {
  77. if (nitems > height || height > sp_) {
  78. MS_LOG(EXCEPTION) << "MoveStack arg error: nitems=" << nitems << " height=" << height;
  79. }
  80. int n = height - nitems;
  81. int src = sp_ - height;
  82. int dst = sp_ - nitems;
  83. for (int i = 0; i < nitems; i++) {
  84. insts_stack_[IntToSize(src + i)] = insts_stack_[IntToSize(dst + i)];
  85. }
  86. Pop(n);
  87. }
  88. BaseRef FinalVM::Ref(int i) {
  89. MS_LOG(DEBUG) << "Ref i:" << i << " sp_:" << sp_;
  90. size_t sp_next = IntToSize(sp_ + i);
  91. if (sp_next < insts_stack_.size()) {
  92. if (utils::isa<PyObjectRef>(insts_stack_[sp_next])) {
  93. py::object value = utils::cast<PyObjectRef>(insts_stack_[sp_next]).object_;
  94. MS_LOG(DEBUG) << "VM ref python:" << py::str(value);
  95. return parse::data_converter::PyDataToValue(value);
  96. }
  97. MS_LOG(DEBUG) << "Ref not python :" << insts_stack_[sp_next].ToString();
  98. return insts_stack_[sp_next];
  99. }
  100. MS_LOG(EXCEPTION) << "IndexError: index(" << sp_next << ") out of range [0, " << insts_stack_.size() << ").";
  101. }
  102. void FinalVM::Pushp() { retp_.push(pc_); }
  103. void FinalVM::Popp() {
  104. if (retp_.empty()) {
  105. MS_LOG(EXCEPTION) << "Stack retp_ is empty";
  106. }
  107. pc_ = retp_.top();
  108. MS_LOG(DEBUG) << "Pop pc:" << pc_ << ", sp:" << sp_;
  109. retp_.pop();
  110. }
  111. void FinalVM::Pushsp() { retsp_.push(sp_); }
  112. void FinalVM::Popsp() {
  113. int sp = retsp_.top();
  114. MS_LOG(DEBUG) << "Current sp:" << sp_ << ", before sp:" << sp << ", " << sp_ - sp;
  115. if (sp_ >= sp) {
  116. Pop(sp_ - sp + 1);
  117. retsp_.pop();
  118. } else {
  119. MS_LOG(EXCEPTION) << "Stack point sp_:" << sp << " must biger than sp:" << sp_;
  120. }
  121. }
  122. void FinalVM::DoJmp(const BaseRef &jmp_orig) {
  123. MS_LOG(DEBUG) << "Start";
  124. BaseRef jmp = jmp_orig;
  125. if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
  126. MS_LOG(DEBUG) << "Start jump StructPartial";
  127. auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
  128. auto args = new_jmp->args_;
  129. InstPadStack(VectorRef(std::vector<BaseRef>{static_cast<int>(args.size())}));
  130. auto iter = args.rbegin();
  131. for (; iter != args.rend(); ++iter) {
  132. Push(*iter);
  133. }
  134. pc_ = new_jmp->fn_;
  135. return;
  136. }
  137. if (!utils::isa<int>(jmp)) {
  138. MS_LOG(EXCEPTION) << "Jmp inst should be a int";
  139. }
  140. pc_ = utils::cast<int>(jmp);
  141. MS_LOG(DEBUG) << "End do jump pc_:" << pc_;
  142. }
  143. BaseRef FinalVM::Eval(const VectorRef &args) {
  144. MS_LOG(DEBUG) << "Start: " << args.size();
  145. insts_stack_.clear();
  146. insts_stack_.resize(args.size());
  147. std::stack<int>().swap(retp_);
  148. retp_.push(-1);
  149. pc_ = 0;
  150. sp_ = 0;
  151. auto riter = args.rbegin();
  152. for (; riter != args.rend(); ++riter) {
  153. if (utils::isa<PyObjectRef>(*riter)) {
  154. PyObjectRef py_ref = utils::cast<PyObjectRef>(*riter);
  155. py::object value = py_ref.object_;
  156. if (py::isinstance<py::bool_>(value)) {
  157. auto a = py::cast<bool>(value);
  158. Push(static_cast<int>(a));
  159. continue;
  160. }
  161. }
  162. Push(*riter);
  163. }
  164. while (pc_ >= 0) {
  165. auto inst = insts_[IntToSize(pc_)];
  166. MS_LOG(DEBUG) << "Loop " << insts_.size() << ", pc:" << pc_ << ", inst:" << inst_str[inst.first];
  167. ++pc_;
  168. auto iter = inst_function_map.find(inst.first);
  169. if (iter != inst_function_map.end()) {
  170. iter->second(inst.second);
  171. } else {
  172. MS_LOG(EXCEPTION) << "Unknown instruction {" << inst_str[inst.first] << "}";
  173. }
  174. }
  175. MS_LOG(DEBUG) << "End";
  176. return insts_stack_[0];
  177. }
  178. void FinalVM::InstCall(const VectorRef &args) {
  179. MS_LOG(DEBUG) << "Start";
  180. const size_t args_size = 1;
  181. if (args.size() != args_size) {
  182. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
  183. << ".";
  184. return;
  185. }
  186. int jmp = utils::cast<int>(args[0]);
  187. MS_LOG(DEBUG) << "Call pushp:" << pc_ << ", jmp:" << jmp << ", sp:" << sp_;
  188. Pushp();
  189. DoJmp(Ref(jmp));
  190. MS_LOG(DEBUG) << "Instcall end sp :" << sp_;
  191. }
  192. void FinalVM::InstTailCall(const VectorRef &args) {
  193. MS_LOG(DEBUG) << "Start";
  194. const size_t args_size = 3;
  195. if (args.size() != args_size) {
  196. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
  197. << ".";
  198. return;
  199. }
  200. int jmp = utils::cast<int>(args[0]);
  201. int height = utils::cast<int>(args[1]);
  202. int nargs = utils::cast<int>(args[2]);
  203. auto new_jmp = Ref(jmp);
  204. MoveStack(nargs, height);
  205. MS_LOG(DEBUG) << "TailCall pushp:" << pc_ << ", jmp:" << jmp;
  206. DoJmp(new_jmp);
  207. MS_LOG(DEBUG) << "End";
  208. }
  209. void FinalVM::InstSwitchReturn(const VectorRef &args) {
  210. MS_LOG(DEBUG) << "Start";
  211. if (args.size() != 1) {
  212. MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
  213. return;
  214. }
  215. Pop(1);
  216. Popsp();
  217. }
  218. void FinalVM::InstReturn(const VectorRef &args) {
  219. MS_LOG(DEBUG) << "Start";
  220. const size_t args_size = 2;
  221. if (args.size() != args_size) {
  222. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
  223. << ".";
  224. return;
  225. }
  226. int rpos = utils::cast<int>(args[0]);
  227. int height = utils::cast<int>(args[1]);
  228. auto rv = Ref(rpos);
  229. Pop(height);
  230. Push(rv);
  231. Popp();
  232. MS_LOG(DEBUG) << "End";
  233. }
  234. void FinalVM::InstRealPartial(const VectorRef &args) {
  235. const size_t args_size = 1;
  236. if (args.size() < args_size) {
  237. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
  238. << args.size() << ".";
  239. return;
  240. }
  241. int fn_ = utils::cast<int>(args[0]);
  242. auto fn = utils::cast<int>(Ref(fn_));
  243. MS_LOG(DEBUG) << "Partial argssize:" << args.size();
  244. std::vector<BaseRef> outs(args.size() - 1);
  245. (void)std::transform(args.begin() + 1, args.end(), outs.begin(),
  246. [&, this](const BaseRef &a) { return Ref(utils::cast<int>(a)); });
  247. Push(std::make_shared<StructPartial>(fn, VectorRef(outs)));
  248. }
  249. void FinalVM::InstPartial(const VectorRef &args) {
  250. MS_LOG(DEBUG) << "Start";
  251. InstRealPartial(args);
  252. MS_LOG(DEBUG) << "End";
  253. }
  254. void FinalVM::InstRealSwitch(const VectorRef &args) {
  255. const size_t args_size = 3;
  256. if (args.size() != args_size) {
  257. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
  258. << ".";
  259. return;
  260. }
  261. int cond = utils::cast<int>(args[0]);
  262. int vtrue = utils::cast<int>(args[1]);
  263. int vfalse = utils::cast<int>(args[2]);
  264. BaseRef c = Ref(cond);
  265. MS_LOG(DEBUG) << vtrue << " false:" << vfalse << " InstSwitch: " << c.ToString();
  266. bool bool_value = false;
  267. if (backend_->GetCond(c, &bool_value)) {
  268. MS_LOG(DEBUG) << "Cond:" << bool_value;
  269. if (bool_value) {
  270. Push(Ref(vtrue));
  271. } else {
  272. Push(Ref(vfalse));
  273. }
  274. } else {
  275. MS_LOG(EXCEPTION) << "Not supported type to be casted to bool";
  276. }
  277. }
  278. void FinalVM::InstSwitch(const VectorRef &args) {
  279. MS_LOG(DEBUG) << "Start";
  280. InstRealSwitch(args);
  281. MS_LOG(DEBUG) << "End";
  282. }
  283. void FinalVM::InstSwitchLayer(const VectorRef &args) {
  284. MS_LOG(DEBUG) << "Start";
  285. const size_t args_size = 2;
  286. if (args.size() != args_size) {
  287. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
  288. << ".";
  289. return;
  290. }
  291. int idx = utils::cast<int>(args[0]);
  292. VectorRef branches = utils::cast<VectorRef>(Ref(utils::cast<int>(args[1])));
  293. int size = static_cast<int>(branches.size());
  294. BaseRef index = Ref(idx);
  295. int idx_value = 0;
  296. if (!backend_->GetIndex(index, &idx_value)) {
  297. MS_LOG(EXCEPTION) << "Not supported type to be casted to int.";
  298. }
  299. auto ori_value = idx_value;
  300. if (idx_value < 0) {
  301. // Add support negative index range [-size, -1].
  302. idx_value += size;
  303. }
  304. if (idx_value < 0 || idx_value >= size) {
  305. MS_EXCEPTION(IndexError) << __FUNCTION__ << " given index " << ori_value
  306. << " out of range. Please make sure the value "
  307. << "of index in [" << -size << ", " << size << "), and the type is int32.";
  308. }
  309. Push(branches[idx_value]);
  310. MS_LOG(DEBUG) << "End";
  311. }
  312. void FinalVM::InstTuple(const VectorRef &args) {
  313. MS_LOG(DEBUG) << "Start";
  314. VectorRef tuple;
  315. auto iter = args.begin();
  316. for (; iter != args.end(); ++iter) {
  317. auto a = utils::cast<int>(*iter);
  318. tuple.push_back(Ref(a));
  319. }
  320. Push(tuple);
  321. MS_LOG(DEBUG) << "End";
  322. }
  323. void FinalVM::InstPush(const VectorRef &args) {
  324. MS_LOG(DEBUG) << "Start";
  325. const size_t args_size = 1;
  326. if (args.size() != args_size) {
  327. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
  328. << ".";
  329. return;
  330. }
  331. auto v = args[0];
  332. Push(v);
  333. MS_LOG(DEBUG) << "End";
  334. }
  335. void FinalVM::InstInput(const VectorRef &args) {
  336. MS_LOG(DEBUG) << "Start";
  337. const size_t args_size = 1;
  338. if (args.size() != args_size) {
  339. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
  340. << ".";
  341. return;
  342. }
  343. int rpos = utils::cast<int>(args[0]);
  344. Push(Ref(rpos));
  345. MS_LOG(DEBUG) << "End";
  346. }
  347. void FinalVM::InstPadStack(const VectorRef &args) {
  348. MS_LOG(DEBUG) << "Start";
  349. const size_t args_size = 1;
  350. if (args.size() != args_size) {
  351. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
  352. << ".";
  353. return;
  354. }
  355. int sz = utils::cast<int>(args[0]);
  356. MS_LOG(DEBUG) << insts_stack_.size() << " need padstack " << sz << " sp_ " << sp_;
  357. size_t stack_size = insts_stack_.size();
  358. int need = sz - (static_cast<int>(stack_size) - sp_);
  359. if (need > 0) {
  360. MS_LOG(DEBUG) << "InstPadStack resize: size:" << insts_stack_.size() << " need pad:" << need;
  361. insts_stack_.resize(stack_size + IntToSize(need));
  362. }
  363. MS_LOG(DEBUG) << "End";
  364. }
  365. void FinalVM::InstExternal(const VectorRef &args) {
  366. MS_LOG(DEBUG) << "Start:" << args.size();
  367. if (args.empty()) {
  368. MS_LOG(EXCEPTION) << "Args is empty!";
  369. }
  370. VectorRef tuple;
  371. RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]);
  372. compile::RunFuncPtr fn = run_ref.func_;
  373. for (size_t i = 2; i < args.size(); ++i) {
  374. auto index = utils::cast<int>(args[i]);
  375. tuple.push_back(Ref(index));
  376. }
  377. if (!fn) {
  378. MS_LOG(EXCEPTION) << "Function not callable";
  379. }
  380. auto outs = (*fn)(tuple);
  381. MS_LOG(DEBUG) << "'fn' out size:" << outs.size();
  382. for (auto &o : outs) {
  383. MS_LOG(DEBUG) << "InstExternal value:" << o.ToString();
  384. Push(o);
  385. }
  386. MS_LOG(DEBUG) << "End";
  387. }
  388. void FinalVM::InstPushPrim(const VectorRef &args) {
  389. MS_LOG(DEBUG) << "Start: " << args.size();
  390. const size_t args_size = 2;
  391. if (args.size() < args_size) {
  392. MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
  393. << args.size() << ".";
  394. return;
  395. }
  396. auto prim = utils::cast<PrimitivePtr>(args[0]);
  397. VectorRef tuple;
  398. for (size_t i = 1; i < args.size(); ++i) {
  399. auto index = utils::cast<int>(args[i]);
  400. tuple.push_back(Ref(index));
  401. }
  402. if (prim->name() == "bprop_cut") {
  403. auto outs = RunHook(prim, tuple);
  404. Push(outs);
  405. } else {
  406. auto outs = RunOperation(prim, tuple);
  407. Push(outs);
  408. }
  409. MS_LOG(DEBUG) << "End";
  410. }
  411. void FinalVM::SyncData(const py::object &arg) {
  412. if (py::isinstance<py::tuple>(arg)) {
  413. auto arg_list = py::cast<py::tuple>(arg);
  414. for (size_t i = 0; i < arg_list.size(); i++) {
  415. SyncData(arg_list[i]);
  416. }
  417. }
  418. if (py::isinstance<tensor::Tensor>(arg)) {
  419. auto tensor = py::cast<tensor::TensorPtr>(arg);
  420. (void)tensor->data_sync();
  421. }
  422. }
  423. BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
  424. MS_LOG(DEBUG) << "input for operation:";
  425. MS_EXCEPTION_IF_NULL(prim);
  426. return prim->RunHookFunction(args);
  427. }
  428. } // namespace compile
  429. } // namespace mindspore