You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transform.cc 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/transform.h"
  19. #include <algorithm>
  20. #include <map>
  21. #include <queue>
  22. #include <stack>
  23. #include <set>
  24. #include <string>
  25. #include <vector>
  26. #include "pipeline/static_analysis/abstract_value.h"
  27. #ifdef ENABLE_GE
  28. #include "transform/convert.h"
  29. #endif
  30. #include "utils/graph_utils.h"
  31. #include "utils/context/ms_context.h"
  32. #include "debug/trace.h"
  33. #include "debug/anf_ir_dump.h"
  34. namespace mindspore {
  35. namespace compile {
  36. using mindspore::abstract::AbstractFunction;
  37. using mindspore::abstract::AbstractFunctionPtr;
  38. using PrimTypePair = std::pair<PrimitivePtr, AbstractFunctionPtr>;
  39. using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
  40. using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
  41. std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
  42. prim::kPrimMakeTuple, prim::kPrimBpropCut};
  43. const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
  44. static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
  45. prim::kPrimBpropCut};
  46. return ms_nonlinear_ops;
  47. }
  48. namespace {
  49. bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
  50. auto context_ptr = MsContext::GetInstance();
  51. MS_EXCEPTION_IF_NULL(context_ptr);
  52. std::string last_target = context_ptr->device_target();
  53. for (auto &node : nodes) {
  54. if (node->isa<CNode>()) {
  55. std::string cur_target = GetCNodeTarget(node);
  56. if (last_target != cur_target) {
  57. return true;
  58. }
  59. last_target = cur_target;
  60. }
  61. }
  62. return false;
  63. }
  64. void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
  65. std::queue<AnfNodePtr> queue;
  66. queue.push(graph->get_return());
  67. std::set<AnfNodePtr> visited;
  68. while (!queue.empty()) {
  69. auto &node = queue.front();
  70. queue.pop();
  71. MS_EXCEPTION_IF_NULL(node);
  72. if (!node->isa<CNode>()) {
  73. continue;
  74. }
  75. auto cnode = node->cast<CNodePtr>();
  76. MS_EXCEPTION_IF_NULL(cnode);
  77. for (auto &input : cnode->inputs()) {
  78. auto iter = nodes_ref->find(input);
  79. if (iter != nodes_ref->end()) {
  80. iter->second++;
  81. } else {
  82. (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1));
  83. }
  84. if (visited.find(input) != visited.end()) {
  85. continue;
  86. }
  87. visited.insert(input);
  88. queue.push(input);
  89. }
  90. }
  91. }
  92. std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
  93. std::vector<AnfNodePtr> result;
  94. std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
  95. std::map<AnfNodePtr, size_t> node_positions;
  96. for (auto &node : nodes) {
  97. if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  98. auto cnode = node->cast<CNodePtr>();
  99. MS_EXCEPTION_IF_NULL(cnode);
  100. auto &inputs = cnode->inputs();
  101. if (inputs.size() < 2) {
  102. MS_LOG(EXCEPTION) << "Invalid get item node";
  103. }
  104. auto &parent = inputs[1];
  105. auto iter = node_positions.find(parent);
  106. if (iter != node_positions.end()) {
  107. size_t position = iter->second;
  108. auto iter_nodes = insert_positions.find(position);
  109. if (iter_nodes != insert_positions.end()) {
  110. iter_nodes->second.push_back(node);
  111. } else {
  112. (void)insert_positions.insert(
  113. std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node}));
  114. }
  115. continue;
  116. }
  117. }
  118. result.emplace_back(node);
  119. node_positions[node] = result.size();
  120. }
  121. size_t insert_num = 0;
  122. for (auto &item : insert_positions) {
  123. size_t position = item.first + insert_num;
  124. (void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
  125. insert_num += item.second.size();
  126. }
  127. return result;
  128. }
  129. std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
  130. std::vector<AnfNodePtr> result;
  131. std::stack<AnfNodePtr> to_visit;
  132. std::stack<AnfNodePtr> next_to_visit;
  133. std::map<AnfNodePtr, size_t> nodes_ref;
  134. CalcNodeRefCount(graph, &nodes_ref);
  135. std::string handle_target = default_target;
  136. std::string next_target = "";
  137. to_visit.push(graph->get_return());
  138. while (!to_visit.empty() || !next_to_visit.empty()) {
  139. if (to_visit.empty()) {
  140. to_visit.swap(next_to_visit);
  141. handle_target = next_target;
  142. }
  143. auto &node = to_visit.top();
  144. MS_EXCEPTION_IF_NULL(node);
  145. to_visit.pop();
  146. result.emplace_back(node);
  147. if (!node->isa<CNode>()) {
  148. continue;
  149. }
  150. auto cnode = node->cast<CNodePtr>();
  151. MS_EXCEPTION_IF_NULL(cnode);
  152. auto node_inputs = cnode->inputs();
  153. std::reverse(node_inputs.begin(), node_inputs.end());
  154. for (auto &input : node_inputs) {
  155. auto iter = nodes_ref.find(input);
  156. if (iter != nodes_ref.end()) {
  157. iter->second--;
  158. if (iter->second != 0) {
  159. continue;
  160. }
  161. }
  162. if (!input->isa<CNode>()) {
  163. to_visit.push(input);
  164. continue;
  165. }
  166. std::string input_target = GetCNodeTarget(input);
  167. if (input_target == handle_target) {
  168. to_visit.push(input);
  169. } else if (next_to_visit.empty() || input_target == next_target) {
  170. next_to_visit.push(input);
  171. next_target = input_target;
  172. } else {
  173. MS_LOG(EXCEPTION) << "only support two different target";
  174. }
  175. }
  176. }
  177. std::reverse(result.begin(), result.end());
  178. return result;
  179. }
  180. } // namespace
  181. CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
  182. : backend_(backend), cut_list_(cut_list) {
  183. MS_EXCEPTION_IF_NULL(backend_);
  184. lin_convert_ = backend_->convert_fn();
  185. if (lin_convert_ == nullptr) {
  186. MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
  187. }
  188. is_gevm_convert_ = false;
  189. if (backend->name() == kGeVm) {
  190. MS_LOG(INFO) << "Attribute 'is_gevm_convert' is true";
  191. is_gevm_convert_ = true;
  192. }
  193. }
  194. bool CompileGraph::IsCut(const AnfNodePtr &node) {
  195. MS_EXCEPTION_IF_NULL(node);
  196. if (node->isa<CNode>()) {
  197. auto cnode = node->cast<CNodePtr>();
  198. auto &inputs = cnode->inputs();
  199. if (inputs.empty()) {
  200. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  201. }
  202. AnfNodePtr fn = inputs[0];
  203. MS_EXCEPTION_IF_NULL(fn);
  204. if (IsValueNode<FuncGraph>(fn)) {
  205. auto fg = GetValueNode<FuncGraphPtr>(fn);
  206. if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  207. return false;
  208. }
  209. }
  210. if (!IsValueNode<Primitive>(fn)) {
  211. return true;
  212. }
  213. PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn);
  214. for (auto &prim : cut_list_) {
  215. MS_EXCEPTION_IF_NULL(prim);
  216. if (prim->name() == node_prim->name()) {
  217. if (prim->name() == prim::kPrimBpropCut->name()) {
  218. auto ms_context = MsContext::GetInstance();
  219. MS_EXCEPTION_IF_NULL(ms_context);
  220. ms_context->set_enable_pynative_hook(true);
  221. }
  222. return true;
  223. }
  224. }
  225. #ifdef ENABLE_GE
  226. if (is_gevm_convert_) {
  227. auto name = GetCNodeFuncName(cnode);
  228. auto adpt = transform::DfGraphConvertor::FindAdapter(name);
  229. if (adpt == nullptr) {
  230. return true;
  231. }
  232. }
  233. #endif
  234. }
  235. return false;
  236. }
  237. VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) {
  238. MS_EXCEPTION_IF_NULL(graph);
  239. auto nodes = OptimizeGetItemOrder(input_nodes);
  240. VectorRef splits;
  241. VectorRef split;
  242. std::string last_target;
  243. for (auto &node : nodes) {
  244. MS_EXCEPTION_IF_NULL(node);
  245. if (IsCut(node)) {
  246. if (split.size() != 0) {
  247. splits.push_back(split);
  248. }
  249. splits.push_back(node);
  250. split.clear();
  251. } else if (node->isa<CNode>()) {
  252. std::string cur_target = GetCNodeTarget(node);
  253. if (cur_target != last_target && !last_target.empty() && split.size() != 0) {
  254. splits.push_back(split);
  255. split.clear();
  256. }
  257. last_target = cur_target;
  258. split.push_back(node);
  259. }
  260. }
  261. return splits;
  262. }
  263. VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
  264. MS_EXCEPTION_IF_NULL(graph);
  265. auto nodes = TopoSort(graph->get_return());
  266. MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
  267. if (ContainMultiTarget(nodes)) {
  268. auto context_ptr = MsContext::GetInstance();
  269. MS_EXCEPTION_IF_NULL(context_ptr);
  270. std::string default_target = context_ptr->device_target();
  271. nodes = SplitSort(graph, default_target);
  272. return SplitNodesWithTarget(nodes, graph);
  273. }
  274. VectorRef splits;
  275. VectorRef split;
  276. for (auto &node : nodes) {
  277. MS_EXCEPTION_IF_NULL(node);
  278. if (IsCut(node)) {
  279. if (split.size() != 0) {
  280. splits.push_back(split);
  281. }
  282. splits.push_back(node);
  283. split.clear();
  284. } else if (node->isa<CNode>()) {
  285. split.push_back(node);
  286. }
  287. }
  288. return splits;
  289. }
  290. // Push the value node on the stack.
  291. void CompileGraph::Push(const AnfNodePtr &node) {
  292. MS_EXCEPTION_IF_NULL(node);
  293. if (slots_.count(node) > 0) {
  294. MS_LOG(EXCEPTION) << "Push failed node in slots:" << node->DebugString()
  295. << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  296. }
  297. MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
  298. << " is parameter: " << node->isa<Parameter>();
  299. slots_[node] = height_;
  300. set_height(height_ + 1);
  301. }
  302. void CompileGraph::AddInst(const Instruction &inst, const int &arg) {
  303. VectorRef args;
  304. args.push_back(arg);
  305. AddInst(inst, args);
  306. }
  307. void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) {
  308. VectorRef args;
  309. args.push_back(arg);
  310. AddInst(inst, args);
  311. }
  312. void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) {
  313. inst_.push_back(std::make_pair(inst, args));
  314. }
  315. // Gets the stack reference for the node value. If the node is a constant,
  316. // it may actually cause the push in to not be mentioned before.
  317. int CompileGraph::Ref(const AnfNodePtr &node) {
  318. MS_EXCEPTION_IF_NULL(node);
  319. MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
  320. if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
  321. if (IsValueNode<FuncGraph>(node)) {
  322. MS_LOG(DEBUG) << "Push graph.";
  323. AddInst(Instruction::kGraph, GetValueNode(node));
  324. } else {
  325. MS_LOG(DEBUG) << "Push.";
  326. if (IsValueNode<Primitive>(node)) {
  327. MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  328. } else {
  329. AddInst(Instruction::kPush, GetValueNode(node));
  330. }
  331. }
  332. Push(node);
  333. }
  334. MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
  335. << ", return: " << slots_[node] - height_;
  336. return slots_[node] - height_;
  337. }
  338. // Make sure the value of node is at the top of the stack.
  339. void CompileGraph::AddInput(const AnfNodePtr &node) {
  340. MS_EXCEPTION_IF_NULL(node);
  341. if (slots_.count(node) == 0) {
  342. MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
  343. (void)Ref(node);
  344. return;
  345. }
  346. AddInst(Instruction::kInput, Ref(node));
  347. set_height(height_ + 1);
  348. }
  349. // Call back effect in stack
  350. void CompileGraph::Ret(int nargs) { set_height(height_ - nargs); }
  351. void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
  352. MS_EXCEPTION_IF_NULL(graph);
  353. std::vector<AnfNodePtr> parameters = graph->parameters();
  354. for (size_t i = parameters.size(); i != 0; i--) {
  355. Push(parameters[i - 1]);
  356. MS_LOG(DEBUG) << "Push parameter " << i - 1 << ": " << parameters[i - 1]->DebugString(true);
  357. }
  358. }
  359. int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, const std::string &target) {
  360. MS_LOG(DEBUG) << "LinConvert start";
  361. LinConvertResult result;
  362. if (backend_->simu_flag()) {
  363. result = backend_->GetMultiGraphRun(graph);
  364. } else {
  365. result = lin_convert_(node_list, target);
  366. }
  367. if (result.run == nullptr) {
  368. MS_LOG(ERROR) << "LinConvert failed";
  369. return RET_FAILED;
  370. }
  371. if (!(*result.run)) {
  372. if (result.inputs.size() != result.outputs.size()) {
  373. MS_EXCEPTION_IF_NULL(graph);
  374. MS_LOG(EXCEPTION) << "must inputs equal outputs NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  375. } else {
  376. size_t size = result.inputs.size();
  377. for (size_t i = 0; i < size; i++) {
  378. Tie(result.inputs[i], result.outputs[i]);
  379. }
  380. return RET_CONTINUE;
  381. }
  382. }
  383. AddExternal(result);
  384. for (auto &o : result.outputs) {
  385. Push(o);
  386. }
  387. return RET_SUCCESS;
  388. }
  389. void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
  390. MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString();
  391. if (backend_->is_multi_graph_sink()) {
  392. VectorRef args;
  393. args.emplace_back(-1);
  394. MS_LOG(DEBUG) << "call::" << height_;
  395. AddInst(Instruction::kCall, args);
  396. args.clear();
  397. args.emplace_back(node->input(1));
  398. AddInst(Instruction::kSwitchReturn, args);
  399. args.clear();
  400. args.emplace_back(false);
  401. args.emplace_back(Ref(node->input(1)));
  402. args.emplace_back(Ref(node->input(2)));
  403. args.emplace_back(Ref(node->input(3)));
  404. AddInst(Instruction::kSwitch, args);
  405. }
  406. }
  407. int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
  408. MS_EXCEPTION_IF_NULL(node);
  409. MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
  410. std::vector<AnfNodePtr> node_inputs = node->inputs();
  411. if (node_inputs.empty()) {
  412. MS_LOG(EXCEPTION) << "The node->inputs() is empty";
  413. }
  414. AnfNodePtr fn = node_inputs[0];
  415. if (IsValueNode<Primitive>(fn)) {
  416. PrimitivePtr value = GetValueNode<PrimitivePtr>(fn);
  417. MS_LOG(DEBUG) << "The fn is primitive " << (*value).name();
  418. for (size_t i = node_inputs.size() - 1; i > 0; i--) {
  419. AddInput(node->input(i));
  420. }
  421. if (IsPrimitive(fn, prim::kPrimReturn)) {
  422. AddReturn(node);
  423. return RET_BREAK;
  424. }
  425. if (IsPrimitive(fn, prim::kPrimPartial)) {
  426. AddPartial(node);
  427. } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
  428. AddSwitch(node);
  429. AddSinkSwitch(node);
  430. } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
  431. AddMakeTuple(node);
  432. } else {
  433. AddPrimitive(node, value);
  434. }
  435. } else {
  436. int ret = AddCall(graph, node);
  437. if (ret == RET_BREAK) {
  438. return ret;
  439. }
  440. }
  441. Push(node);
  442. return RET_SUCCESS;
  443. }
  444. void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) {
  445. auto ret = LinConvert(graph, {});
  446. if (ret == RET_FAILED) {
  447. MS_LOG(EXCEPTION) << "MultiGraphRun failed.";
  448. }
  449. AddReturn(nullptr);
  450. }
  451. bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
  452. MS_LOG(DEBUG) << "Start split graph";
  453. MS_EXCEPTION_IF_NULL(graph);
  454. VectorRef splits = SplitNodes(graph);
  455. MS_LOG(DEBUG) << "Split nodes size:" << splits.size();
  456. for (auto &split : splits) {
  457. int ret = RET_SUCCESS;
  458. if (utils::isa<VectorRef>(split)) {
  459. MS_LOG(DEBUG) << "Start a extern LinConvert";
  460. std::vector<AnfNodePtr> args;
  461. auto vec_ref = utils::cast<VectorRef>(split);
  462. (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args),
  463. [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); });
  464. if (args.size() > 0) {
  465. std::string cur_target = GetCNodeTarget(args[0]);
  466. ret = LinConvert(graph, args, cur_target);
  467. } else {
  468. ret = LinConvert(graph, args);
  469. }
  470. MS_LOG(DEBUG) << "End a extern LinConvert";
  471. if (ret == RET_FAILED) {
  472. return false;
  473. }
  474. if (ret == RET_CONTINUE) {
  475. continue;
  476. }
  477. } else {
  478. MS_LOG(DEBUG) << "Start a cut node";
  479. if (!(utils::isa<AnfNodePtr>(split) && utils::cast<AnfNodePtr>(split)->isa<CNode>())) {
  480. MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  481. }
  482. CNodePtr node = utils::cast<AnfNodePtr>(split)->cast<CNodePtr>();
  483. ret = InterpretNode(graph, node);
  484. MS_LOG(DEBUG) << "End a cut node";
  485. if (ret == RET_BREAK) {
  486. break;
  487. }
  488. }
  489. }
  490. MS_LOG(DEBUG) << "End split graph";
  491. return true;
  492. }
  493. InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) {
  494. InstSet inst = Run(graph);
  495. return inst;
  496. }
  497. InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
  498. MS_EXCEPTION_IF_NULL(graph);
  499. Reset();
  500. PushParameters(graph);
  501. int param_height = height_;
  502. MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
  503. if (backend_->simu_flag()) {
  504. GenMultiGraphsRun(graph);
  505. } else {
  506. if (!SplitGraph(graph)) {
  507. return inst_;
  508. }
  509. }
  510. AddPadStack(param_height);
  511. auto ret = inst_;
  512. Reset();
  513. return ret;
  514. }
  515. void CompileGraph::AddPadStack(int param_height) {
  516. int stack_sizes = max_height_ - param_height;
  517. MS_LOG(DEBUG) << "Pad stack max_height_:" << max_height_ << " param:" << param_height
  518. << " need_stack:" << stack_sizes;
  519. if (stack_sizes > 0) {
  520. VectorRef need_stacks({stack_sizes});
  521. (void)inst_.insert(inst_.begin(), std::make_pair(Instruction::kPadStack, need_stacks));
  522. }
  523. }
  524. void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
  525. VectorRef args;
  526. args.emplace_back(Ref(fn));
  527. args.emplace_back(height_);
  528. args.emplace_back(static_cast<int>(size - 1));
  529. MS_LOG(DEBUG) << "Tail call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
  530. AddInst(Instruction::kTailCall, args);
  531. }
  532. void CompileGraph::AddPartial(const CNodePtr &node) {
  533. auto inputs = node->inputs();
  534. VectorRef args;
  535. auto fn = inputs[1];
  536. if (!IsValueNode<FuncGraph>(fn)) {
  537. MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
  538. }
  539. if (backend_->is_multi_graph_sink()) {
  540. auto func_graph = GetValueNode<FuncGraphPtr>(fn);
  541. args.emplace_back(func_graph);
  542. AnfNodePtrList outs(inputs.begin() + 2, inputs.end());
  543. backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  544. }
  545. for (size_t i = 1; i < inputs.size(); i++) {
  546. args.emplace_back(Ref(inputs[i]));
  547. }
  548. AddInst(Instruction::kPartial, args);
  549. }
  550. void CompileGraph::AddMakeTuple(const CNodePtr &node) {
  551. auto inputs = node->inputs();
  552. VectorRef args;
  553. for (size_t i = 1; i < inputs.size(); i++) {
  554. args.emplace_back(Ref(inputs[i]));
  555. }
  556. AddInst(Instruction::kTuple, args);
  557. }
  558. void CompileGraph::AddSwitch(const CNodePtr &node) {
  559. auto inputs = node->inputs();
  560. if (inputs.size() < 4) {
  561. MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
  562. }
  563. VectorRef args;
  564. if (backend_->is_multi_graph_sink()) {
  565. args.emplace_back(true);
  566. }
  567. args.emplace_back(Ref(inputs[1]));
  568. args.emplace_back(Ref(inputs[2]));
  569. args.emplace_back(Ref(inputs[3]));
  570. AddInst(Instruction::kSwitch, args);
  571. }
  572. void CompileGraph::AddReturn(const CNodePtr &node) {
  573. VectorRef args;
  574. if (backend_->simu_flag()) {
  575. args.emplace_back(Ref(backend_->final_output()));
  576. } else {
  577. args.emplace_back(Ref(node->input(1)));
  578. }
  579. args.emplace_back(height_);
  580. AddInst(Instruction::kReturn, args);
  581. }
  582. void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) {
  583. auto inputs = node->inputs();
  584. VectorRef args;
  585. args.push_back(prim);
  586. for (size_t i = 1; i < inputs.size(); i++) {
  587. args.emplace_back(Ref(inputs[i]));
  588. }
  589. AddInst(Instruction::kPrim, args);
  590. }
  591. int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
  592. auto inputs = node->inputs();
  593. AnfNodePtr fn = inputs[0];
  594. if (backend_->is_multi_graph_sink() && IsValueNode<FuncGraph>(fn)) {
  595. auto func_graph = GetValueNode<FuncGraphPtr>(fn);
  596. AnfNodePtrList outs(inputs.begin() + 1, inputs.end());
  597. backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  598. }
  599. (void)Ref(fn);
  600. size_t size = inputs.size();
  601. for (size_t i = size - 1; i > 0; i--) {
  602. AddInput(inputs[i]);
  603. }
  604. if (node == graph->output()) {
  605. AddTailCall(fn, size);
  606. return RET_BREAK;
  607. }
  608. MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
  609. AddInst(Instruction::kCall, Ref(fn));
  610. Ret(static_cast<int>(size - 1));
  611. return RET_SUCCESS;
  612. }
  613. void CompileGraph::AddExternal(const LinConvertResult &result) {
  614. VectorRef args;
  615. args.push_back(result.run);
  616. args.push_back(result.simu_run);
  617. size_t size = result.inputs.size();
  618. for (size_t i = 0; i < size; i++) {
  619. args.emplace_back(Ref(result.inputs[i]));
  620. }
  621. AddInst(Instruction::kExternal, args);
  622. }
  623. void TraverseGraphMap(
  624. const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphSet &fgs,
  625. const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
  626. MS_EXCEPTION_IF_NULL(manager_ptr);
  627. MS_EXCEPTION_IF_NULL(tr);
  628. for (const auto &fg : fgs) {
  629. for (const auto &ct_any : fg->value_nodes()) {
  630. AnfNodePtr const_primitive_node = ct_any.first;
  631. if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
  632. auto users = manager_ptr->node_users()[const_primitive_node];
  633. for (auto &use : users) {
  634. CNodePtr node = use.first->cast<CNodePtr>();
  635. MS_EXCEPTION_IF_NULL(node);
  636. if (node->func_graph() != fg) {
  637. continue;
  638. }
  639. int key = use.second;
  640. if (key != 0) {
  641. MS_EXCEPTION_IF_NULL(node->input(0));
  642. bool key_is_const = node->input(0)->isa<ValueNode>();
  643. PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
  644. if (value != nullptr) {
  645. bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
  646. bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
  647. if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
  648. continue;
  649. }
  650. }
  651. FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
  652. dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
  653. tr->SetEdge(node, key, NewValueNode(g));
  654. }
  655. }
  656. }
  657. }
  658. }
  659. }
  660. FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
  661. MS_EXCEPTION_IF_NULL(graph);
  662. FuncGraphManagerPtr manager_ptr = graph->manager();
  663. MS_EXCEPTION_IF_NULL(manager_ptr);
  664. MapPrimTypeFuncGraph prim_graphs;
  665. auto get_prim_graph = [&](const PrimitivePtr &prim, const AbstractFunctionPtr &type) {
  666. PrimTypePair prim_type = std::make_pair(prim, type);
  667. if (prim_graphs.end() == prim_graphs.find(prim_type)) {
  668. FuncGraphPtr g = std::make_shared<FuncGraph>();
  669. std::vector<AnfNodePtr> args;
  670. ValueNodePtr prim_ct = NewValueNode(prim);
  671. MS_EXCEPTION_IF_NULL(prim_ct);
  672. prim_ct->set_abstract(type);
  673. args.push_back(prim_ct);
  674. MS_EXCEPTION_IF_NULL(type);
  675. TypedPrimitiveAbstractClosurePtr tp = dyn_cast<abstract::TypedPrimitiveAbstractClosure>(type->GetUnique());
  676. MS_EXCEPTION_IF_NULL(tp);
  677. MS_EXCEPTION_IF_NULL(g);
  678. for (auto t : tp->args_spec_list()) {
  679. ParameterPtr p = g->add_parameter();
  680. p->set_abstract(t);
  681. args.push_back(p);
  682. }
  683. AnfNodePtr out = g->NewCNode(args);
  684. out->set_abstract(tp->output());
  685. g->set_output(out);
  686. prim_graphs[prim_type] = g;
  687. }
  688. return prim_graphs[prim_type];
  689. };
  690. FuncGraphTransaction tr = manager_ptr->Transact();
  691. auto &fgs = manager_ptr->func_graphs();
  692. TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
  693. tr.Commit();
  694. return graph;
  695. }
  696. CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
  697. MS_EXCEPTION_IF_NULL(backend);
  698. MS_LOG(DEBUG) << "Start vm: " << backend->name();
  699. transform_ = std::make_shared<CompileGraph>(backend, cut_list);
  700. Reset();
  701. }
  702. // Convert graphs to unlinked instructions.
  703. void CompileGraphs::Compile(const FuncGraphPtr &graph) {
  704. MS_LOG(DEBUG) << "Start";
  705. mapping_[graph] = static_cast<int>(insts_.size());
  706. if (transform_ != nullptr) {
  707. InstSet insts = transform_->Run(graph);
  708. if (!insts.empty()) {
  709. (void)insts_.insert(insts_.end(), insts.begin(), insts.end());
  710. }
  711. }
  712. MS_LOG(DEBUG) << "End";
  713. }
  714. // Link instructions from multiple function graphs together.
  715. FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) {
  716. MS_LOG(DEBUG) << "Start";
  717. for (std::size_t i = 0; i < insts_.size(); i++) {
  718. InstType inst = insts_[i];
  719. MS_LOG(DEBUG) << "Link point:" << inst_str[inst.first];
  720. if (Instruction::kGraph == inst.first) {
  721. if (inst.second.empty()) {
  722. MS_LOG(EXCEPTION) << "The second element of inst is empty";
  723. }
  724. FuncGraphPtr func_graph = utils::cast<ValuePtr>(inst.second[0])->cast<FuncGraphPtr>();
  725. MS_LOG(DEBUG) << "Link graph:" << func_graph->ToString();
  726. insts_[i] = std::make_pair(Instruction::kPush, VectorRef(std::vector<BaseRef>{mapping_[func_graph]}));
  727. }
  728. }
  729. FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
  730. if (backend_->is_multi_graph_sink()) {
  731. backend_->set_simu_flag(true);
  732. MS_LOG(DEBUG) << "Start simulate";
  733. backend_->SimulateRun(rt, graph);
  734. MS_LOG(DEBUG) << "Link graphs";
  735. insts_ = transform_->GenMultiGraphsSinkInst(graph);
  736. rt->set_insts(insts_);
  737. backend_->set_simu_flag(false);
  738. MS_LOG(DEBUG) << "End start simulate";
  739. backend_->Link(kInvalidGraphId);
  740. }
  741. MS_LOG(DEBUG) << "End";
  742. return rt;
  743. }
  744. // Convert all graphs to unlinked instructions and link them.
  745. FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
  746. MS_EXCEPTION_IF_NULL(graph);
  747. MS_LOG(DEBUG) << "Start";
  748. Reset();
  749. MS_LOG(DEBUG) << "Begin parameter:" << graph->parameters().size();
  750. FuncGraphPtr prim_graph = WrapPrimitives(graph);
  751. Compile(prim_graph);
  752. MS_EXCEPTION_IF_NULL(prim_graph);
  753. FuncGraphSet graphs = prim_graph->manager()->func_graphs();
  754. for (auto g : graphs) {
  755. if (g != graph && g != nullptr && !(g->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
  756. Compile(g);
  757. }
  758. }
  759. FinalVMPtr rt = Link(graph);
  760. Reset();
  761. MS_LOG(DEBUG) << "End";
  762. return rt;
  763. }
  764. bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) {
  765. MS_EXCEPTION_IF_NULL(graph);
  766. auto graph_manager = graph->manager();
  767. MS_EXCEPTION_IF_NULL(graph_manager);
  768. FuncGraphSet graphs = graph_manager->func_graphs();
  769. for (auto &g : graphs) {
  770. auto nodes = TopoSort(g->get_return());
  771. if (ContainMultiTarget(nodes)) {
  772. return true;
  773. }
  774. }
  775. return false;
  776. }
  777. BackendPtr CreateBackend() {
  778. auto context_ptr = MsContext::GetInstance();
  779. MS_EXCEPTION_IF_NULL(context_ptr);
  780. std::string name = context_ptr->backend_policy();
  781. MS_LOG(INFO) << "CreateBackend is: " << name;
  782. if (backend_list.count(name) == 0) {
  783. MS_LOG(EXCEPTION) << "Backend is error: " << name;
  784. }
  785. if (name == kMsConvert) {
  786. std::string target = context_ptr->device_target();
  787. uint32_t device_id = context_ptr->device_id();
  788. auto backend = std::make_shared<MsBackend>(name, target, device_id);
  789. std::string device_target = MsContext::GetInstance()->device_target();
  790. if (device_target == kAscendDevice) {
  791. if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
  792. backend->set_is_multi_graph_sink(false);
  793. context_ptr->set_is_multi_graph_sink(false);
  794. } else {
  795. backend->set_is_multi_graph_sink(true);
  796. context_ptr->set_is_multi_graph_sink(true);
  797. }
  798. }
  799. return backend;
  800. }
  801. return std::make_shared<Backend>(name);
  802. }
  803. } // namespace compile
  804. } // namespace mindspore