You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transform.cc 22 kB

4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/transform.h"
  19. #include <algorithm>
  20. #include <map>
  21. #include <queue>
  22. #include <string>
  23. #include <vector>
  24. #include "abstract/abstract_value.h"
  25. #ifdef ENABLE_GE
  26. #include "transform/graph_ir/convert.h"
  27. #endif
  28. #include "ir/graph_utils.h"
  29. #include "utils/ms_context.h"
  30. #include "debug/trace.h"
  31. #include "debug/anf_ir_dump.h"
  32. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  33. #include "ps/ps_context.h"
  34. #endif
  35. namespace mindspore {
  36. namespace compile {
  37. using mindspore::abstract::AbstractFunction;
  38. using mindspore::abstract::AbstractFunctionPtr;
  39. using PrimTypePair = std::pair<PrimitivePtr, AbstractFunctionPtr>;
  40. using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
  41. using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
  42. std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
  43. prim::kPrimMakeTuple, prim::kPrimBpropCut};
  44. std::vector<PrimitivePtr> control_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, prim::kPrimMakeTuple,
  45. prim::kPrimSwitchLayer};
  46. const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
  47. static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial,
  48. prim::kPrimSwitch, prim::kPrimMakeTuple,
  49. prim::kPrimBpropCut, prim::kPrimSwitchLayer};
  50. return ms_nonlinear_ops;
  51. }
  52. CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
  53. MS_EXCEPTION_IF_NULL(backend_);
  54. lin_convert_ = backend_->convert_fn();
  55. if (lin_convert_ == nullptr) {
  56. MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
  57. }
  58. graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name());
  59. }
  60. // Push the value node on the stack.
  61. void CompileGraph::Push(const AnfNodePtr &node) {
  62. MS_EXCEPTION_IF_NULL(node);
  63. if (slots_.count(node) > 0) {
  64. MS_LOG(WARNING) << "Push failed node in slots:" << node->DebugString()
  65. << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  66. return;
  67. }
  68. MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
  69. << " is parameter: " << node->isa<Parameter>();
  70. slots_[node] = height_;
  71. set_height(height_ + 1);
  72. }
  73. void CompileGraph::AddInst(const Instruction &inst, const int64_t &arg) {
  74. VectorRef args;
  75. args.push_back(arg);
  76. AddInst(inst, args);
  77. }
  78. void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) {
  79. VectorRef args;
  80. args.push_back(arg);
  81. AddInst(inst, args);
  82. }
  83. void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) {
  84. inst_.push_back(std::make_pair(inst, args));
  85. }
  86. // Gets the stack reference for the node value. If the node is a constant,
  87. // it may actually cause the push in to not be mentioned before.
  88. int64_t CompileGraph::Ref(const AnfNodePtr &node) {
  89. MS_EXCEPTION_IF_NULL(node);
  90. MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
  91. if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
  92. if (IsValueNode<FuncGraph>(node)) {
  93. MS_LOG(DEBUG) << "Push graph.";
  94. AddInst(Instruction::kGraph, GetValueNode(node));
  95. } else {
  96. MS_LOG(DEBUG) << "Push.";
  97. if (IsValueNode<Primitive>(node)) {
  98. MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  99. } else {
  100. AddInst(Instruction::kPush, GetValueNode(node));
  101. }
  102. }
  103. Push(node);
  104. }
  105. MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
  106. << ", return: " << slots_[node] - height_;
  107. return slots_[node] - height_;
  108. }
  109. // Make sure the value of node is at the top of the stack.
  110. void CompileGraph::AddInput(const AnfNodePtr &node) {
  111. MS_EXCEPTION_IF_NULL(node);
  112. if (slots_.count(node) == 0) {
  113. MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
  114. (void)Ref(node);
  115. return;
  116. }
  117. AddInst(Instruction::kInput, Ref(node));
  118. set_height(height_ + 1);
  119. }
  120. // Call back effect in stack
  121. void CompileGraph::Ret(int64_t nargs) { set_height(height_ - nargs); }
  122. void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
  123. MS_EXCEPTION_IF_NULL(graph);
  124. std::vector<AnfNodePtr> parameters = graph->parameters();
  125. for (size_t i = parameters.size(); i != 0; i--) {
  126. MS_EXCEPTION_IF_NULL(parameters[i - 1]);
  127. Push(parameters[i - 1]);
  128. MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << parameters[i - 1]->DebugString(true);
  129. }
  130. }
  131. int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
  132. MS_EXCEPTION_IF_NULL(segment);
  133. MS_LOG(DEBUG) << "LinConvert start";
  134. LinConvertResult result;
  135. result = lin_convert_(segment, target);
  136. if (result.run == nullptr) {
  137. MS_LOG(ERROR) << "LinConvert failed";
  138. return RET_FAILED;
  139. }
  140. if (!(*result.run)) {
  141. if (result.inputs.size() != result.outputs.size()) {
  142. MS_EXCEPTION_IF_NULL(graph);
  143. MS_LOG(EXCEPTION) << "must inputs equal outputs NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  144. } else {
  145. size_t size = result.inputs.size();
  146. for (size_t i = 0; i < size; i++) {
  147. Tie(result.inputs[i], result.outputs[i]);
  148. }
  149. return RET_CONTINUE;
  150. }
  151. }
  152. AddExternal(result);
  153. return RET_SUCCESS;
  154. }
  155. int64_t CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
  156. MS_EXCEPTION_IF_NULL(node);
  157. MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
  158. std::vector<AnfNodePtr> node_inputs = node->inputs();
  159. if (node_inputs.empty()) {
  160. MS_LOG(EXCEPTION) << "The node->inputs() is empty";
  161. }
  162. AnfNodePtr fn = node_inputs[0];
  163. if (IsValueNode<Primitive>(fn)) {
  164. PrimitivePtr value = GetValueNode<PrimitivePtr>(fn);
  165. MS_LOG(DEBUG) << "The fn is primitive " << (*value).name();
  166. for (size_t i = node_inputs.size() - 1; i > 0; i--) {
  167. AddInput(node->input(i));
  168. }
  169. if (IsPrimitive(fn, prim::kPrimReturn)) {
  170. AddReturn(node);
  171. return RET_BREAK;
  172. }
  173. if (IsPrimitive(fn, prim::kPrimPartial)) {
  174. AddPartial(node);
  175. } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
  176. AddSwitch(node);
  177. } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
  178. AddSwitchLayer(node);
  179. } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
  180. AddMakeTuple(node);
  181. } else {
  182. AddPrimitive(node, value);
  183. }
  184. } else {
  185. int64_t ret = AddCall(graph, node);
  186. if (ret == RET_BREAK) {
  187. return ret;
  188. }
  189. }
  190. Push(node);
  191. return RET_SUCCESS;
  192. }
  193. bool CompileGraph::Compile(const FuncGraphPtr &graph) {
  194. MS_LOG(DEBUG) << "Start split graph";
  195. MS_EXCEPTION_IF_NULL(graph);
  196. MS_EXCEPTION_IF_NULL(graph_partition_);
  197. auto segments = graph_partition_->Partition(graph);
  198. MS_LOG(DEBUG) << "Split nodes size:" << segments.size();
  199. for (auto &segment : segments) {
  200. MS_EXCEPTION_IF_NULL(segment);
  201. int64_t ret = RET_SUCCESS;
  202. if (!segment->is_cut_) {
  203. MS_LOG(DEBUG) << "Start a extern LinConvert";
  204. if (!segment->nodes_.empty()) {
  205. std::string cur_target = GetCNodeTarget(segment->nodes_[0]);
  206. ret = LinConvert(graph, segment, cur_target);
  207. } else {
  208. ret = LinConvert(graph, segment);
  209. }
  210. MS_LOG(DEBUG) << "End a extern LinConvert";
  211. if (ret == RET_FAILED) {
  212. return false;
  213. }
  214. if (ret == RET_CONTINUE) {
  215. continue;
  216. }
  217. } else if (!segment->nodes_.empty()) {
  218. MS_LOG(DEBUG) << "Start a cut node";
  219. auto &cut_node = segment->nodes_[0];
  220. MS_EXCEPTION_IF_NULL(cut_node);
  221. if (!cut_node->isa<CNode>()) {
  222. MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  223. }
  224. auto node = cut_node->cast<CNodePtr>();
  225. ret = InterpretNode(graph, node);
  226. MS_LOG(DEBUG) << "End a cut node";
  227. if (ret == RET_BREAK) {
  228. break;
  229. }
  230. }
  231. }
  232. MS_LOG(DEBUG) << "End split graph";
  233. return true;
  234. }
  235. InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
  236. MS_EXCEPTION_IF_NULL(graph);
  237. Reset();
  238. PushParameters(graph);
  239. int64_t param_height = height_;
  240. MS_EXCEPTION_IF_NULL(graph->get_return());
  241. MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
  242. if (!Compile(graph)) {
  243. return inst_;
  244. }
  245. AddPadStack(param_height);
  246. auto ret = inst_;
  247. Reset();
  248. return ret;
  249. }
  250. void CompileGraph::AddPadStack(int64_t param_height) {
  251. int64_t stack_sizes = max_height_ - param_height;
  252. MS_LOG(DEBUG) << "Pad stack max_height_:" << max_height_ << " param:" << param_height
  253. << " need_stack:" << stack_sizes;
  254. if (stack_sizes > 0) {
  255. VectorRef need_stacks({stack_sizes});
  256. (void)inst_.insert(inst_.begin(), std::make_pair(Instruction::kPadStack, need_stacks));
  257. }
  258. }
  259. void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
  260. VectorRef args;
  261. args.emplace_back(Ref(fn));
  262. args.emplace_back(height_);
  263. args.emplace_back(static_cast<int64_t>(size - 1));
  264. MS_LOG(DEBUG) << "Tail call:" << Ref(fn) << ", " << height_ << ", " << (size - 1);
  265. AddInst(Instruction::kTailCall, args);
  266. }
  267. void CompileGraph::AddPartial(const CNodePtr &node) {
  268. MS_EXCEPTION_IF_NULL(node);
  269. auto inputs = node->inputs();
  270. VectorRef args;
  271. if (inputs.size() <= 1) {
  272. MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
  273. }
  274. auto fn = inputs[1];
  275. if (!IsValueNode<FuncGraph>(fn)) {
  276. MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
  277. }
  278. for (size_t i = 1; i < inputs.size(); i++) {
  279. args.emplace_back(Ref(inputs[i]));
  280. }
  281. AddInst(Instruction::kPartial, args);
  282. }
  283. void CompileGraph::AddMakeTuple(const CNodePtr &node) {
  284. MS_EXCEPTION_IF_NULL(node);
  285. auto inputs = node->inputs();
  286. VectorRef args;
  287. for (size_t i = 1; i < inputs.size(); i++) {
  288. args.emplace_back(Ref(inputs[i]));
  289. }
  290. AddInst(Instruction::kTuple, args);
  291. }
  292. void CompileGraph::AddSwitch(const CNodePtr &node) {
  293. MS_EXCEPTION_IF_NULL(node);
  294. auto inputs = node->inputs();
  295. if (inputs.size() < kSwitchInputSize) {
  296. MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
  297. }
  298. VectorRef args;
  299. args.emplace_back(Ref(inputs[kCallKernelGraphIndex]));
  300. args.emplace_back(Ref(inputs[kSwitchTrueKernelGraphIndex]));
  301. args.emplace_back(Ref(inputs[kSwitchFalseKernelGraphIndex]));
  302. AddInst(Instruction::kSwitch, args);
  303. }
  304. void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
  305. MS_EXCEPTION_IF_NULL(node);
  306. auto inputs = node->inputs();
  307. if (inputs.size() != kSwitchLayerInputSize) {
  308. MS_LOG(EXCEPTION) << "Switch layer must have index and branches.";
  309. }
  310. VectorRef args;
  311. const size_t cond_index = 1;
  312. const size_t tuple_index = 2;
  313. args.emplace_back(Ref(inputs[cond_index]));
  314. args.emplace_back(Ref(inputs[tuple_index]));
  315. AddInst(Instruction::kSwitchLayer, args);
  316. }
  317. void CompileGraph::AddReturn(const CNodePtr &node) {
  318. MS_EXCEPTION_IF_NULL(node);
  319. VectorRef args;
  320. if (node->inputs().size() <= 1) {
  321. MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
  322. }
  323. args.emplace_back(Ref(node->input(1)));
  324. args.emplace_back(height_);
  325. AddInst(Instruction::kReturn, args);
  326. }
  327. void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) {
  328. MS_EXCEPTION_IF_NULL(node);
  329. auto inputs = node->inputs();
  330. VectorRef args;
  331. args.push_back(prim);
  332. for (size_t i = 1; i < inputs.size(); i++) {
  333. args.emplace_back(Ref(inputs[i]));
  334. }
  335. AddInst(Instruction::kPrim, args);
  336. }
  337. int64_t CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
  338. MS_EXCEPTION_IF_NULL(graph);
  339. MS_EXCEPTION_IF_NULL(node);
  340. auto inputs = node->inputs();
  341. if (inputs.empty()) {
  342. MS_LOG(EXCEPTION) << "The node->inputs() is empty.";
  343. }
  344. AnfNodePtr fn = inputs[0];
  345. (void)Ref(fn);
  346. size_t size = inputs.size();
  347. for (size_t i = size - 1; i > 0; i--) {
  348. AddInput(inputs[i]);
  349. }
  350. if (node == graph->output()) {
  351. AddTailCall(fn, size);
  352. return RET_BREAK;
  353. }
  354. MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << (size - 1);
  355. AddInst(Instruction::kCall, Ref(fn));
  356. Ret(static_cast<int64_t>(size - 1));
  357. for (size_t i = size - 1; i > 0; i--) {
  358. const auto iter = slots_.find(inputs[i]);
  359. if (iter != slots_.end() && iter->second >= height_) {
  360. slots_.erase(inputs[i]);
  361. }
  362. }
  363. return RET_SUCCESS;
  364. }
  365. void CompileGraph::AddExternal(const LinConvertResult &result) {
  366. VectorRef args;
  367. args.push_back(result.run);
  368. args.push_back(result.simu_run);
  369. size_t size = result.inputs.size();
  370. for (size_t i = 0; i < size; i++) {
  371. args.emplace_back(Ref(result.inputs[i]));
  372. }
  373. AddInst(Instruction::kExternal, args);
  374. for (auto &out : result.outputs) {
  375. Push(out);
  376. }
  377. }
  378. void TraverseGraphMap(
  379. const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphSet &fgs,
  380. const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
  381. MS_EXCEPTION_IF_NULL(manager_ptr);
  382. MS_EXCEPTION_IF_NULL(tr);
  383. for (const auto &fg : fgs) {
  384. MS_EXCEPTION_IF_NULL(fg);
  385. for (const auto &ct_any : fg->value_nodes()) {
  386. AnfNodePtr const_primitive_node = ct_any.first;
  387. if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
  388. auto users = manager_ptr->node_users()[const_primitive_node];
  389. for (auto &use : users) {
  390. CNodePtr node = use.first->cast<CNodePtr>();
  391. MS_EXCEPTION_IF_NULL(node);
  392. if (node->func_graph() != fg) {
  393. continue;
  394. }
  395. int64_t key = use.second;
  396. if (key != 0) {
  397. MS_EXCEPTION_IF_NULL(node->input(0));
  398. bool key_is_const = node->input(0)->isa<ValueNode>();
  399. PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
  400. if (value != nullptr) {
  401. bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
  402. bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
  403. if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
  404. continue;
  405. }
  406. }
  407. FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
  408. dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
  409. tr->SetEdge(node, key, NewValueNode(g));
  410. }
  411. }
  412. }
  413. }
  414. }
  415. }
  416. FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
  417. MS_EXCEPTION_IF_NULL(graph);
  418. FuncGraphManagerPtr manager_ptr = graph->manager();
  419. MS_EXCEPTION_IF_NULL(manager_ptr);
  420. MapPrimTypeFuncGraph prim_graphs;
  421. auto get_prim_graph = [&prim_graphs](const PrimitivePtr &prim, const AbstractFunctionPtr &type) {
  422. PrimTypePair prim_type = std::make_pair(prim, type);
  423. if (prim_graphs.end() == prim_graphs.find(prim_type)) {
  424. FuncGraphPtr g = std::make_shared<FuncGraph>();
  425. std::vector<AnfNodePtr> args;
  426. ValueNodePtr prim_ct = NewValueNode(prim);
  427. MS_EXCEPTION_IF_NULL(prim_ct);
  428. prim_ct->set_abstract(type);
  429. args.push_back(prim_ct);
  430. MS_EXCEPTION_IF_NULL(type);
  431. TypedPrimitiveAbstractClosurePtr tp = dyn_cast<abstract::TypedPrimitiveAbstractClosure>(type->GetUnique());
  432. MS_EXCEPTION_IF_NULL(tp);
  433. MS_EXCEPTION_IF_NULL(g);
  434. for (auto t : tp->args_spec_list()) {
  435. ParameterPtr p = g->add_parameter();
  436. p->set_abstract(t);
  437. args.push_back(p);
  438. }
  439. AnfNodePtr out = g->NewCNode(args);
  440. out->set_abstract(tp->output());
  441. g->set_output(out);
  442. prim_graphs[prim_type] = g;
  443. }
  444. return prim_graphs[prim_type];
  445. };
  446. FuncGraphTransaction tr = manager_ptr->Transact();
  447. auto &fgs = manager_ptr->func_graphs();
  448. TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
  449. tr.Commit();
  450. return graph;
  451. }
  452. CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
  453. MS_EXCEPTION_IF_NULL(backend);
  454. MS_LOG(DEBUG) << "Start vm: " << backend->name();
  455. transform_ = std::make_shared<CompileGraph>(backend, cut_list);
  456. Reset();
  457. }
  458. // Convert graphs to unlinked instructions.
  459. void CompileGraphs::Compile(const FuncGraphPtr &graph) {
  460. MS_LOG(DEBUG) << "Start";
  461. mapping_[graph] = static_cast<int64_t>(insts_.size());
  462. if (transform_ != nullptr) {
  463. InstSet insts = transform_->Run(graph);
  464. if (!insts.empty()) {
  465. (void)insts_.insert(insts_.end(), insts.begin(), insts.end());
  466. }
  467. }
  468. MS_LOG(DEBUG) << "End";
  469. }
  470. // Link instructions from multiple function graphs together.
  471. FinalVMPtr CompileGraphs::Link() {
  472. MS_LOG(DEBUG) << "Start";
  473. for (std::size_t i = 0; i < insts_.size(); i++) {
  474. InstType inst = insts_[i];
  475. MS_LOG(DEBUG) << "Link point:" << inst_str[inst.first];
  476. if (Instruction::kGraph == inst.first) {
  477. if (inst.second.empty()) {
  478. MS_LOG(EXCEPTION) << "The second element of inst is empty";
  479. }
  480. FuncGraphPtr func_graph = utils::cast<ValuePtr>(inst.second[0])->cast<FuncGraphPtr>();
  481. MS_LOG(DEBUG) << "Link graph:" << func_graph->ToString();
  482. insts_[i] = std::make_pair(Instruction::kPush, VectorRef(std::vector<BaseRef>{mapping_[func_graph]}));
  483. }
  484. }
  485. FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
  486. MS_LOG(DEBUG) << "End";
  487. return rt;
  488. }
  489. // Convert all graphs to unlinked instructions and link them.
  490. FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
  491. MS_EXCEPTION_IF_NULL(graph);
  492. MS_LOG(DEBUG) << "Start";
  493. Reset();
  494. MS_LOG(DEBUG) << "Begin parameter:" << graph->parameters().size();
  495. FuncGraphPtr prim_graph = WrapPrimitives(graph);
  496. Compile(prim_graph);
  497. MS_EXCEPTION_IF_NULL(prim_graph);
  498. MS_EXCEPTION_IF_NULL(prim_graph->manager());
  499. FuncGraphSet graphs = prim_graph->manager()->func_graphs();
  500. for (auto g : graphs) {
  501. if (g != graph && g != nullptr) {
  502. Compile(g);
  503. }
  504. }
  505. FinalVMPtr rt = Link();
  506. Reset();
  507. MS_LOG(DEBUG) << "End";
  508. return rt;
  509. }
  510. BackendPtr CreateBackend() {
  511. auto context_ptr = MsContext::GetInstance();
  512. MS_EXCEPTION_IF_NULL(context_ptr);
  513. std::string name = context_ptr->backend_policy();
  514. MS_LOG(INFO) << "CreateBackend is: " << name;
  515. if (backend_list.count(name) == 0) {
  516. MS_LOG(EXCEPTION) << "Backend is error: " << name;
  517. }
  518. if (name == kMsConvert) {
  519. std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  520. uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  521. BackendPtr backend = nullptr;
  522. // Create MindRTBackend or MsBackend according to whether mindrt is used.
  523. if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
  524. backend = std::make_shared<MindRTBackend>(name, target, device_id);
  525. } else {
  526. backend = std::make_shared<MsBackend>(name, target, device_id);
  527. }
  528. if (target == kAscendDevice) {
  529. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
  530. backend->set_is_multi_graph_sink(false);
  531. context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
  532. } else {
  533. auto single_op = common::GetEnv(kGraphOpRun);
  534. if (single_op == "1") {
  535. context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
  536. }
  537. auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
  538. if (enable_mem_scheduler == "1") {
  539. context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, true);
  540. context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
  541. }
  542. }
  543. }
  544. return backend;
  545. }
  546. return std::make_shared<Backend>(name);
  547. }
  548. void SetMindRTEnable() {
  549. auto context_ptr = MsContext::GetInstance();
  550. MS_EXCEPTION_IF_NULL(context_ptr);
  551. if (context_ptr->get_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT)) {
  552. return;
  553. }
  554. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  555. if (ps::PSContext::instance()->is_ps_mode()) {
  556. return;
  557. }
  558. #endif
  559. std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  560. if (common::GetEnv("ENABLE_ASCEND_MINDRT") == "1" || common::kEnableAscendMindRT) {
  561. // exception scenario: still run original process after enable ascend mindrt
  562. auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  563. bool is_pynative_infer = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
  564. if (target == kAscendDevice && (mode == kPynativeMode || is_pynative_infer)) {
  565. context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
  566. return;
  567. }
  568. if ((common::GetEnv(kGraphOpRun) == "1" || common::GetEnv(kEnableMemScheduler) == "1") && target == kAscendDevice) {
  569. return;
  570. }
  571. } else {
  572. if ((target != kGPUDevice) && (target != kCPUDevice)) {
  573. return;
  574. }
  575. }
  576. #if defined(_WIN32) || defined(_WIN64)
  577. return;
  578. #endif
  579. MS_LOG(DEBUG) << "Enable mindRT.";
  580. context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, true);
  581. }
  582. } // namespace compile
  583. } // namespace mindspore