You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

transform.cc 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/transform.h"
  19. #include <algorithm>
  20. #include <map>
  21. #include <string>
  22. #include <vector>
  23. #include "pipeline/static_analysis/abstract_value.h"
  24. #ifdef ENABLE_GE
  25. #include "transform/convert.h"
  26. #endif
  27. #include "utils/graph_utils.h"
  28. #include "utils/context/ms_context.h"
  29. #include "debug/trace.h"
  30. namespace mindspore {
  31. namespace compile {
  32. using mindspore::abstract::AbstractFunction;
  33. using mindspore::abstract::AbstractFunctionPtr;
  34. using PrimTypePair = std::pair<PrimitivePtr, AbstractFunctionPtr>;
  35. using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
  36. using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
  37. std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
  38. prim::kPrimMakeTuple};
  39. const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
  40. static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch};
  41. return ms_nonlinear_ops;
  42. }
  43. CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
  44. : backend_(backend), cut_list_(cut_list) {
  45. MS_EXCEPTION_IF_NULL(backend_);
  46. lin_convert_ = backend_->convert_fn();
  47. if (lin_convert_ == nullptr) {
  48. MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
  49. }
  50. is_gevm_convert_ = false;
  51. if (backend->name() == kGeVm) {
  52. MS_LOG(INFO) << "Attribute 'is_gevm_convert' is true";
  53. is_gevm_convert_ = true;
  54. }
  55. }
  56. bool CompileGraph::IsCut(const AnfNodePtr &node) {
  57. MS_EXCEPTION_IF_NULL(node);
  58. if (node->isa<CNode>()) {
  59. auto cnode = node->cast<CNodePtr>();
  60. auto &inputs = cnode->inputs();
  61. if (inputs.empty()) {
  62. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  63. }
  64. AnfNodePtr fn = inputs[0];
  65. if (!IsValueNode<Primitive>(fn)) {
  66. return true;
  67. }
  68. PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn);
  69. for (auto &prim : cut_list_) {
  70. MS_EXCEPTION_IF_NULL(prim);
  71. if (prim->name() == node_prim->name()) {
  72. return true;
  73. }
  74. }
  75. #ifdef ENABLE_GE
  76. if (is_gevm_convert_) {
  77. auto name = GetCNodeFuncName(cnode);
  78. auto adpt = transform::DfGraphConvertor::FindAdapter(name);
  79. if (adpt == nullptr) {
  80. return true;
  81. }
  82. }
  83. #endif
  84. }
  85. return false;
  86. }
  87. VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
  88. MS_EXCEPTION_IF_NULL(graph);
  89. VectorRef splits;
  90. VectorRef split;
  91. std::vector<AnfNodePtr> nodes = TopoSort(graph->get_return());
  92. MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
  93. for (auto &node : nodes) {
  94. MS_EXCEPTION_IF_NULL(node);
  95. if (IsCut(node)) {
  96. MS_LOG(DEBUG) << "Cut node:" << node->DebugString(10) << ", size:" << split.size();
  97. if (split.size() != 0) {
  98. splits.push_back(split);
  99. }
  100. splits.push_back(node);
  101. split.clear();
  102. } else if (!(node->isa<ValueNode>() || node->isa<Parameter>())) {
  103. split.push_back(node);
  104. MS_LOG(DEBUG) << "Insert node:" << node->DebugString(10) << ", size:" << split.size();
  105. }
  106. }
  107. MS_LOG(DEBUG) << "Split node size :" << splits.size();
  108. return splits;
  109. }
  110. // Push the value node on the stack.
  111. void CompileGraph::Push(const AnfNodePtr &node) {
  112. MS_EXCEPTION_IF_NULL(node);
  113. if (slots_.count(node) > 0) {
  114. MS_LOG(EXCEPTION) << "Push failed node in slots:" << node->DebugString()
  115. << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  116. }
  117. MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
  118. << " is parameter: " << node->isa<Parameter>();
  119. slots_[node] = height_;
  120. set_height(height_ + 1);
  121. }
  122. void CompileGraph::AddInst(const Instruction &inst, const int &arg) {
  123. VectorRef args;
  124. args.push_back(arg);
  125. AddInst(inst, args);
  126. }
  127. void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) {
  128. VectorRef args;
  129. args.push_back(arg);
  130. AddInst(inst, args);
  131. }
  132. void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) {
  133. inst_.push_back(std::make_pair(inst, args));
  134. }
  135. // Gets the stack reference for the node value. If the node is a constant,
  136. // it may actually cause the push in to not be mentioned before.
  137. int CompileGraph::Ref(const AnfNodePtr &node) {
  138. MS_EXCEPTION_IF_NULL(node);
  139. MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
  140. if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
  141. if (IsValueNode<FuncGraph>(node)) {
  142. MS_LOG(DEBUG) << "Push graph.";
  143. AddInst(Instruction::kGraph, GetValueNode(node));
  144. } else {
  145. MS_LOG(DEBUG) << "Push.";
  146. if (IsValueNode<Primitive>(node)) {
  147. MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  148. } else {
  149. AddInst(Instruction::kPush, GetValueNode(node));
  150. }
  151. }
  152. Push(node);
  153. }
  154. MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
  155. << ", return: " << slots_[node] - height_;
  156. return slots_[node] - height_;
  157. }
  158. // Make sure the value of node is at the top of the stack.
  159. void CompileGraph::AddInput(const AnfNodePtr &node) {
  160. MS_EXCEPTION_IF_NULL(node);
  161. if (slots_.count(node) == 0) {
  162. MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
  163. (void)Ref(node);
  164. return;
  165. }
  166. AddInst(Instruction::kInput, Ref(node));
  167. set_height(height_ + 1);
  168. }
  169. // Call back effect in stack
  170. void CompileGraph::Ret(int nargs) { set_height(height_ - nargs); }
  171. void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
  172. MS_EXCEPTION_IF_NULL(graph);
  173. std::vector<AnfNodePtr> parameters = graph->parameters();
  174. for (size_t i = parameters.size(); i != 0; i--) {
  175. Push(parameters[i - 1]);
  176. MS_LOG(DEBUG) << "Push parameter " << i - 1 << ": " << parameters[i - 1]->DebugString(true);
  177. }
  178. }
  179. int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) {
  180. MS_LOG(DEBUG) << "LinConvert start";
  181. LinConvertResult result;
  182. if (backend_->simu_flag()) {
  183. result = backend_->GetMultiGraphRun(graph);
  184. } else {
  185. result = lin_convert_(node_list);
  186. }
  187. if (result.run == nullptr) {
  188. MS_LOG(ERROR) << "LinConvert failed";
  189. return RET_FAILED;
  190. }
  191. if (!(*result.run)) {
  192. if (result.inputs.size() != result.outputs.size()) {
  193. MS_EXCEPTION_IF_NULL(graph);
  194. MS_LOG(EXCEPTION) << "must inputs equal outputs NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  195. } else {
  196. size_t size = result.inputs.size();
  197. for (size_t i = 0; i < size; i++) {
  198. Tie(result.inputs[i], result.outputs[i]);
  199. }
  200. return RET_CONTINUE;
  201. }
  202. }
  203. AddExternal(result);
  204. for (auto &o : result.outputs) {
  205. Push(o);
  206. }
  207. return RET_SUCCESS;
  208. }
  209. void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
  210. MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString();
  211. if (backend_->is_multi_graph_sink()) {
  212. VectorRef args;
  213. args.emplace_back(-1);
  214. MS_LOG(DEBUG) << "call::" << height_;
  215. AddInst(Instruction::kCall, args);
  216. args.clear();
  217. args.emplace_back(true);
  218. AddInst(Instruction::kSwitchReturn, args);
  219. args.clear();
  220. args.emplace_back(false);
  221. args.emplace_back(Ref(node->input(1)));
  222. args.emplace_back(Ref(node->input(2)));
  223. args.emplace_back(Ref(node->input(3)));
  224. AddInst(Instruction::kSwitch, args);
  225. }
  226. }
  227. int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
  228. MS_EXCEPTION_IF_NULL(node);
  229. MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
  230. std::vector<AnfNodePtr> node_inputs = node->inputs();
  231. if (node_inputs.empty()) {
  232. MS_LOG(EXCEPTION) << "The node->inputs() is empty";
  233. }
  234. AnfNodePtr fn = node_inputs[0];
  235. if (IsValueNode<Primitive>(fn)) {
  236. PrimitivePtr value = GetValueNode<PrimitivePtr>(fn);
  237. MS_LOG(DEBUG) << "The fn is primitive " << (*value).name();
  238. for (size_t i = node_inputs.size() - 1; i > 0; i--) {
  239. AddInput(node->input(i));
  240. }
  241. if (IsPrimitive(fn, prim::kPrimReturn)) {
  242. AddReturn(node);
  243. return RET_BREAK;
  244. }
  245. if (IsPrimitive(fn, prim::kPrimPartial)) {
  246. AddPartial(node);
  247. } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
  248. AddSwitch(node);
  249. AddSinkSwitch(node);
  250. } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
  251. AddMakeTuple(node);
  252. } else {
  253. AddPrimitive(node, value);
  254. }
  255. } else {
  256. int ret = AddCall(graph, node);
  257. if (ret == RET_BREAK) {
  258. return ret;
  259. }
  260. }
  261. Push(node);
  262. return RET_SUCCESS;
  263. }
  264. void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) {
  265. auto ret = LinConvert(graph, {});
  266. if (ret == RET_FAILED) {
  267. MS_LOG(EXCEPTION) << "MultiGraphRun failed.";
  268. }
  269. AddReturn(nullptr);
  270. }
  271. bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
  272. MS_LOG(DEBUG) << "Start split graph";
  273. MS_EXCEPTION_IF_NULL(graph);
  274. VectorRef splits = SplitNodes(graph);
  275. MS_LOG(DEBUG) << "Split nodes size:" << splits.size();
  276. for (auto &split : splits) {
  277. int ret = RET_SUCCESS;
  278. if (utils::isa<VectorRef>(split)) {
  279. MS_LOG(DEBUG) << "Start a extern LinConvert";
  280. std::vector<AnfNodePtr> args;
  281. auto vec_ref = utils::cast<VectorRef>(split);
  282. (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args),
  283. [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); });
  284. ret = LinConvert(graph, args);
  285. MS_LOG(DEBUG) << "End a extern LinConvert";
  286. if (ret == RET_FAILED) {
  287. return false;
  288. }
  289. if (ret == RET_CONTINUE) {
  290. continue;
  291. }
  292. } else {
  293. MS_LOG(DEBUG) << "Start a cut node";
  294. if (!(utils::isa<AnfNodePtr>(split) && utils::cast<AnfNodePtr>(split)->isa<CNode>())) {
  295. MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  296. }
  297. CNodePtr node = utils::cast<AnfNodePtr>(split)->cast<CNodePtr>();
  298. ret = InterpretNode(graph, node);
  299. MS_LOG(DEBUG) << "End a cut node";
  300. if (ret == RET_BREAK) {
  301. break;
  302. }
  303. }
  304. }
  305. MS_LOG(DEBUG) << "End split graph";
  306. return true;
  307. }
  308. InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) {
  309. InstSet inst = Run(graph);
  310. return inst;
  311. }
  312. InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
  313. MS_EXCEPTION_IF_NULL(graph);
  314. MS_LOG(DEBUG) << "Compile start graph: " << graph->ToString();
  315. Reset();
  316. PushParameters(graph);
  317. int param_height = height_;
  318. MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
  319. if (backend_->simu_flag()) {
  320. GenMultiGraphsRun(graph);
  321. } else {
  322. if (!SplitGraph(graph)) {
  323. return inst_;
  324. }
  325. }
  326. AddPadStack(param_height);
  327. auto ret = inst_;
  328. Reset();
  329. return ret;
  330. }
  331. void CompileGraph::AddPadStack(int param_height) {
  332. int stack_sizes = max_height_ - param_height;
  333. MS_LOG(DEBUG) << "Pad stack max_height_:" << max_height_ << " param:" << param_height
  334. << " need_stack:" << stack_sizes;
  335. if (stack_sizes > 0) {
  336. VectorRef need_stacks({stack_sizes});
  337. (void)inst_.insert(inst_.begin(), std::make_pair(Instruction::kPadStack, need_stacks));
  338. }
  339. }
  340. void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
  341. VectorRef args;
  342. args.emplace_back(Ref(fn));
  343. args.emplace_back(height_);
  344. args.emplace_back(static_cast<int>(size - 1));
  345. MS_LOG(DEBUG) << "Tail call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
  346. AddInst(Instruction::kTailCall, args);
  347. }
  348. void CompileGraph::AddPartial(const CNodePtr &node) {
  349. auto inputs = node->inputs();
  350. VectorRef args;
  351. auto fn = inputs[1];
  352. if (!IsValueNode<FuncGraph>(fn)) {
  353. MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
  354. }
  355. if (backend_->is_multi_graph_sink()) {
  356. auto func_graph = GetValueNode<FuncGraphPtr>(fn);
  357. args.emplace_back(func_graph);
  358. AnfNodePtrList outs(inputs.begin() + 2, inputs.end());
  359. backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  360. }
  361. for (size_t i = 1; i < inputs.size(); i++) {
  362. args.emplace_back(Ref(inputs[i]));
  363. }
  364. AddInst(Instruction::kPartial, args);
  365. }
  366. void CompileGraph::AddMakeTuple(const CNodePtr &node) {
  367. auto inputs = node->inputs();
  368. VectorRef args;
  369. for (size_t i = 1; i < inputs.size(); i++) {
  370. args.emplace_back(Ref(inputs[i]));
  371. }
  372. AddInst(Instruction::kTuple, args);
  373. }
  374. void CompileGraph::AddSwitch(const CNodePtr &node) {
  375. auto inputs = node->inputs();
  376. if (inputs.size() < 4) {
  377. MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
  378. }
  379. VectorRef args;
  380. if (backend_->is_multi_graph_sink()) {
  381. args.emplace_back(true);
  382. }
  383. args.emplace_back(Ref(inputs[1]));
  384. args.emplace_back(Ref(inputs[2]));
  385. args.emplace_back(Ref(inputs[3]));
  386. AddInst(Instruction::kSwitch, args);
  387. }
  388. void CompileGraph::AddReturn(const CNodePtr &node) {
  389. VectorRef args;
  390. if (backend_->simu_flag()) {
  391. args.emplace_back(Ref(backend_->final_output()));
  392. } else {
  393. args.emplace_back(Ref(node->input(1)));
  394. }
  395. args.emplace_back(height_);
  396. AddInst(Instruction::kReturn, args);
  397. }
  398. void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) {
  399. auto inputs = node->inputs();
  400. VectorRef args;
  401. args.push_back(prim);
  402. for (size_t i = 1; i < inputs.size(); i++) {
  403. args.emplace_back(Ref(inputs[i]));
  404. }
  405. AddInst(Instruction::kPrim, args);
  406. }
  407. int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
  408. auto inputs = node->inputs();
  409. AnfNodePtr fn = inputs[0];
  410. if (backend_->is_multi_graph_sink() && IsValueNode<FuncGraph>(fn)) {
  411. auto func_graph = GetValueNode<FuncGraphPtr>(fn);
  412. AnfNodePtrList outs(inputs.begin() + 1, inputs.end());
  413. backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  414. }
  415. (void)Ref(fn);
  416. size_t size = inputs.size();
  417. for (size_t i = size - 1; i > 0; i--) {
  418. AddInput(inputs[i]);
  419. }
  420. if (node == graph->output()) {
  421. AddTailCall(fn, size);
  422. return RET_BREAK;
  423. }
  424. MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
  425. AddInst(Instruction::kCall, Ref(fn));
  426. Ret(static_cast<int>(size - 1));
  427. return RET_SUCCESS;
  428. }
  429. void CompileGraph::AddExternal(const LinConvertResult &result) {
  430. VectorRef args;
  431. args.push_back(result.run);
  432. args.push_back(result.simu_run);
  433. size_t size = result.inputs.size();
  434. for (size_t i = 0; i < size; i++) {
  435. args.emplace_back(Ref(result.inputs[i]));
  436. }
  437. AddInst(Instruction::kExternal, args);
  438. }
  439. void TraverseGraphMap(
  440. const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts,
  441. const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
  442. MS_EXCEPTION_IF_NULL(manager_ptr);
  443. MS_EXCEPTION_IF_NULL(tr);
  444. for (const auto &ct_graphs : cts) {
  445. for (const auto &ct_any : ct_graphs.second) {
  446. AnfNodePtr const_primitive_node = ct_any.first;
  447. if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
  448. auto users = manager_ptr->node_users()[const_primitive_node];
  449. for (auto &use : users) {
  450. CNodePtr node = use.first->cast<CNodePtr>();
  451. MS_EXCEPTION_IF_NULL(node);
  452. int key = use.second;
  453. if (key != 0) {
  454. MS_EXCEPTION_IF_NULL(node->input(0));
  455. bool key_is_const = node->input(0)->isa<ValueNode>();
  456. PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
  457. bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
  458. bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
  459. if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
  460. continue;
  461. }
  462. FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
  463. dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
  464. tr->SetEdge(node, key, NewValueNode(g));
  465. }
  466. }
  467. }
  468. }
  469. }
  470. }
  471. FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
  472. MS_EXCEPTION_IF_NULL(graph);
  473. FuncGraphManagerPtr manager_ptr = graph->manager();
  474. MS_EXCEPTION_IF_NULL(manager_ptr);
  475. MapPrimTypeFuncGraph prim_graphs;
  476. auto get_prim_graph = [&](const PrimitivePtr &prim, const AbstractFunctionPtr &type) {
  477. PrimTypePair prim_type = std::make_pair(prim, type);
  478. if (prim_graphs.end() == prim_graphs.find(prim_type)) {
  479. FuncGraphPtr g = std::make_shared<FuncGraph>();
  480. std::vector<AnfNodePtr> args;
  481. ValueNodePtr prim_ct = NewValueNode(prim);
  482. MS_EXCEPTION_IF_NULL(prim_ct);
  483. prim_ct->set_abstract(type);
  484. args.push_back(prim_ct);
  485. MS_EXCEPTION_IF_NULL(type);
  486. TypedPrimitiveAbstractClosurePtr tp = dyn_cast<abstract::TypedPrimitiveAbstractClosure>(type->GetUnique());
  487. MS_EXCEPTION_IF_NULL(tp);
  488. MS_EXCEPTION_IF_NULL(g);
  489. for (auto t : tp->args_spec_list()) {
  490. ParameterPtr p = g->add_parameter();
  491. p->set_abstract(t);
  492. args.push_back(p);
  493. }
  494. AnfNodePtr out = g->NewCNode(args);
  495. out->set_abstract(tp->output());
  496. g->set_output(out);
  497. prim_graphs[prim_type] = g;
  498. }
  499. return prim_graphs[prim_type];
  500. };
  501. FuncGraphTransaction tr = manager_ptr->Transact();
  502. auto &cts = manager_ptr->valuenodes();
  503. TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph);
  504. return graph;
  505. }
  506. CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
  507. MS_EXCEPTION_IF_NULL(backend);
  508. MS_LOG(DEBUG) << "Start vm: " << backend->name();
  509. transform_ = std::make_shared<CompileGraph>(backend, cut_list);
  510. Reset();
  511. }
  512. // Convert graphs to unlinked instructions.
  513. void CompileGraphs::Compile(const FuncGraphPtr &graph) {
  514. MS_LOG(DEBUG) << "Start";
  515. auto graph_manager = graph->manager();
  516. MS_EXCEPTION_IF_NULL(graph_manager);
  517. FuncGraphSet graphs = graph_manager->func_graphs();
  518. for (auto &g : graphs) {
  519. mapping_[g] = static_cast<int>(insts_.size());
  520. if (transform_ != nullptr) {
  521. InstSet insts = transform_->Run(g);
  522. if (!insts.empty()) {
  523. (void)insts_.insert(insts_.end(), insts.begin(), insts.end());
  524. }
  525. }
  526. }
  527. MS_LOG(DEBUG) << "End";
  528. }
  529. // Link instructions from multiple function graphs together.
  530. FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) {
  531. MS_LOG(DEBUG) << "Start";
  532. for (std::size_t i = 0; i < insts_.size(); i++) {
  533. InstType inst = insts_[i];
  534. MS_LOG(DEBUG) << "Link point:" << inst_str[inst.first];
  535. if (Instruction::kGraph == inst.first) {
  536. if (inst.second.empty()) {
  537. MS_LOG(EXCEPTION) << "The second element of inst is empty";
  538. }
  539. FuncGraphPtr func_graph = utils::cast<ValuePtr>(inst.second[0])->cast<FuncGraphPtr>();
  540. MS_LOG(DEBUG) << "Link graph:" << func_graph->ToString();
  541. insts_[i] = std::make_pair(Instruction::kPush, VectorRef(std::vector<BaseRef>{mapping_[func_graph]}));
  542. }
  543. }
  544. FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
  545. if (backend_->is_multi_graph_sink()) {
  546. backend_->set_simu_flag(true);
  547. MS_LOG(DEBUG) << "Start simulate";
  548. backend_->SimulateRun(rt, graph);
  549. MS_LOG(DEBUG) << "Link graphs";
  550. insts_ = transform_->GenMultiGraphsSinkInst(graph);
  551. rt->set_insts(insts_);
  552. backend_->set_simu_flag(false);
  553. MS_LOG(DEBUG) << "End start simulate";
  554. backend_->Link(kInvalidGraphId);
  555. }
  556. MS_LOG(DEBUG) << "End";
  557. return rt;
  558. }
  559. // Convert all graphs to unlinked instructions and link them.
  560. FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
  561. MS_EXCEPTION_IF_NULL(graph);
  562. MS_LOG(DEBUG) << "Start";
  563. Reset();
  564. MS_LOG(DEBUG) << "Begin parameter:" << graph->parameters().size();
  565. (void)WrapPrimitives(graph);
  566. Compile(graph);
  567. FinalVMPtr rt = Link(graph);
  568. Reset();
  569. MS_LOG(DEBUG) << "End";
  570. return rt;
  571. }
  572. BackendPtr CreateBackend() {
  573. auto context_ptr = MsContext::GetInstance();
  574. MS_EXCEPTION_IF_NULL(context_ptr);
  575. std::string name = context_ptr->backend_policy();
  576. MS_LOG(INFO) << "CreateBackend is: " << name;
  577. if (backend_list.count(name) == 0) {
  578. MS_LOG(EXCEPTION) << "Backend is error: " << name;
  579. }
  580. if (name == kMsConvert) {
  581. std::string target = context_ptr->device_target();
  582. uint32_t device_id = context_ptr->device_id();
  583. auto backend = std::make_shared<MsBackend>(name, target, device_id);
  584. std::string device_target = MsContext::GetInstance()->device_target();
  585. if (device_target == kAscendDevice) {
  586. backend->set_is_multi_graph_sink(true);
  587. context_ptr->set_is_multi_graph_sink(true);
  588. }
  589. return backend;
  590. }
  591. return std::make_shared<Backend>(name);
  592. }
  593. } // namespace compile
  594. } // namespace mindspore