You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

transform.cc 31 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/transform.h"
  19. #include <algorithm>
  20. #include <map>
  21. #include <queue>
  22. #include <stack>
  23. #include <set>
  24. #include <string>
  25. #include <vector>
  26. #include "abstract/abstract_value.h"
  27. #ifdef ENABLE_GE
  28. #include "transform/graph_ir/convert.h"
  29. #endif
  30. #include "ir/graph_utils.h"
  31. #include "utils/ms_context.h"
  32. #include "debug/trace.h"
  33. #include "debug/anf_ir_dump.h"
  34. namespace mindspore {
  35. namespace compile {
  36. using mindspore::abstract::AbstractFunction;
  37. using mindspore::abstract::AbstractFunctionPtr;
  38. using PrimTypePair = std::pair<PrimitivePtr, AbstractFunctionPtr>;
  39. using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
  40. using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
  41. std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
  42. prim::kPrimMakeTuple, prim::kPrimBpropCut};
  43. const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
  44. static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial,
  45. prim::kPrimSwitch, prim::kPrimMakeTuple,
  46. prim::kPrimBpropCut, prim::kPrimSwitchLayer};
  47. return ms_nonlinear_ops;
  48. }
  49. namespace {
  50. bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
  51. auto context_ptr = MsContext::GetInstance();
  52. MS_EXCEPTION_IF_NULL(context_ptr);
  53. std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  54. for (auto &node : nodes) {
  55. if (node->isa<CNode>()) {
  56. std::string cur_target = GetCNodeTarget(node);
  57. if (last_target != cur_target) {
  58. return true;
  59. }
  60. last_target = cur_target;
  61. }
  62. }
  63. return false;
  64. }
  65. bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node,
  66. std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) {
  67. MS_EXCEPTION_IF_NULL(prior_node);
  68. MS_EXCEPTION_IF_NULL(behind_node);
  69. MS_EXCEPTION_IF_NULL(graph);
  70. auto manager = graph->manager();
  71. MS_EXCEPTION_IF_NULL(manager);
  72. auto &node_users = manager->node_users();
  73. if (prior_node->isa<Parameter>()) {
  74. for (auto &user : node_users[prior_node]) {
  75. auto cnode = user.first->cast<CNodePtr>();
  76. MS_EXCEPTION_IF_NULL(cnode);
  77. if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
  78. prior_nodes->emplace_back(cnode);
  79. }
  80. }
  81. } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) {
  82. prior_nodes->emplace_back(prior_node);
  83. } else {
  84. return false;
  85. }
  86. if (behind_node->isa<Parameter>()) {
  87. for (auto &user : node_users[behind_node]) {
  88. auto cnode = user.first->cast<CNodePtr>();
  89. MS_EXCEPTION_IF_NULL(cnode);
  90. if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
  91. depend_nodes->emplace_back(cnode);
  92. }
  93. }
  94. } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) {
  95. depend_nodes->emplace_back(behind_node);
  96. } else {
  97. return false;
  98. }
  99. return true;
  100. }
  101. void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node,
  102. std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges,
  103. std::map<AnfNodePtr, size_t> *nodes_ref) {
  104. MS_EXCEPTION_IF_NULL(node);
  105. auto input_cnode = node->cast<CNodePtr>();
  106. MS_EXCEPTION_IF_NULL(input_cnode);
  107. auto prior_node = input_cnode->input(kControlDependPriorIndex);
  108. auto depend_node = input_cnode->input(kControlDependBehindIndex);
  109. MS_EXCEPTION_IF_NULL(prior_node);
  110. MS_EXCEPTION_IF_NULL(depend_node);
  111. PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0));
  112. MS_EXCEPTION_IF_NULL(prim_ptr);
  113. ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
  114. int depend_mode = 0;
  115. if (mode_ptr != nullptr) {
  116. depend_mode = GetValue<int>(mode_ptr);
  117. }
  118. if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) {
  119. return;
  120. }
  121. std::vector<AnfNodePtr> prior_nodes;
  122. std::vector<AnfNodePtr> behind_nodes;
  123. if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) {
  124. return;
  125. }
  126. for (auto &first_node : prior_nodes) {
  127. for (auto &second_node : behind_nodes) {
  128. MS_EXCEPTION_IF_NULL(first_node);
  129. MS_EXCEPTION_IF_NULL(second_node);
  130. auto iter = control_edges->find(second_node);
  131. if (iter == control_edges->end()) {
  132. (void)control_edges->insert(
  133. std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node}));
  134. } else {
  135. iter->second.emplace_back(first_node);
  136. }
  137. auto ref_iter = nodes_ref->find(first_node);
  138. if (ref_iter != nodes_ref->end()) {
  139. ref_iter->second++;
  140. } else {
  141. (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1));
  142. }
  143. }
  144. }
  145. }
  146. void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref,
  147. std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) {
  148. std::queue<AnfNodePtr> queue;
  149. queue.push(graph->get_return());
  150. std::set<AnfNodePtr> visited;
  151. while (!queue.empty()) {
  152. auto &node = queue.front();
  153. queue.pop();
  154. MS_EXCEPTION_IF_NULL(node);
  155. if (!node->isa<CNode>()) {
  156. continue;
  157. }
  158. auto cnode = node->cast<CNodePtr>();
  159. MS_EXCEPTION_IF_NULL(cnode);
  160. for (auto &input : cnode->inputs()) {
  161. if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) {
  162. AddControlEdge(graph, input, control_edges, nodes_ref);
  163. }
  164. auto iter = nodes_ref->find(input);
  165. if (iter != nodes_ref->end()) {
  166. iter->second++;
  167. } else {
  168. (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1));
  169. }
  170. if (visited.find(input) != visited.end()) {
  171. continue;
  172. }
  173. visited.insert(input);
  174. queue.push(input);
  175. }
  176. }
  177. }
  178. std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
  179. std::vector<AnfNodePtr> result;
  180. std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
  181. std::map<AnfNodePtr, size_t> node_positions;
  182. for (auto &node : nodes) {
  183. if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  184. auto cnode = node->cast<CNodePtr>();
  185. MS_EXCEPTION_IF_NULL(cnode);
  186. auto &inputs = cnode->inputs();
  187. if (inputs.size() < 2) {
  188. MS_LOG(EXCEPTION) << "Invalid get item node";
  189. }
  190. auto &parent = inputs[1];
  191. auto iter = node_positions.find(parent);
  192. if (iter != node_positions.end()) {
  193. size_t position = iter->second;
  194. auto iter_nodes = insert_positions.find(position);
  195. if (iter_nodes != insert_positions.end()) {
  196. iter_nodes->second.push_back(node);
  197. } else {
  198. (void)insert_positions.insert(
  199. std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node}));
  200. }
  201. continue;
  202. }
  203. }
  204. result.emplace_back(node);
  205. node_positions[node] = result.size();
  206. }
  207. size_t insert_num = 0;
  208. for (auto &item : insert_positions) {
  209. size_t position = item.first + insert_num;
  210. (void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
  211. insert_num += item.second.size();
  212. }
  213. return result;
  214. }
  215. std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
  216. std::vector<AnfNodePtr> result;
  217. std::stack<AnfNodePtr> to_visit;
  218. std::stack<AnfNodePtr> next_to_visit;
  219. std::map<AnfNodePtr, size_t> nodes_ref;
  220. std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
  221. CalcNodeRefCount(graph, &nodes_ref, &control_edges);
  222. std::string handle_target = default_target;
  223. std::string next_target = "";
  224. to_visit.push(graph->get_return());
  225. while (!to_visit.empty() || !next_to_visit.empty()) {
  226. if (to_visit.empty()) {
  227. to_visit.swap(next_to_visit);
  228. handle_target = next_target;
  229. }
  230. auto &node = to_visit.top();
  231. MS_EXCEPTION_IF_NULL(node);
  232. to_visit.pop();
  233. result.emplace_back(node);
  234. if (!node->isa<CNode>()) {
  235. continue;
  236. }
  237. auto cnode = node->cast<CNodePtr>();
  238. MS_EXCEPTION_IF_NULL(cnode);
  239. auto node_inputs = cnode->inputs();
  240. std::reverse(node_inputs.begin(), node_inputs.end());
  241. auto ctrl_inputs = control_edges.find(node);
  242. if (ctrl_inputs != control_edges.end()) {
  243. node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
  244. }
  245. for (auto &input : node_inputs) {
  246. auto iter = nodes_ref.find(input);
  247. if (iter != nodes_ref.end()) {
  248. iter->second--;
  249. if (iter->second != 0) {
  250. continue;
  251. }
  252. }
  253. if (!input->isa<CNode>()) {
  254. to_visit.push(input);
  255. continue;
  256. }
  257. std::string input_target = GetCNodeTarget(input);
  258. if (input_target == handle_target) {
  259. to_visit.push(input);
  260. } else if (next_to_visit.empty() || input_target == next_target) {
  261. next_to_visit.push(input);
  262. next_target = input_target;
  263. } else {
  264. MS_LOG(EXCEPTION) << "only support two different target";
  265. }
  266. }
  267. }
  268. std::reverse(result.begin(), result.end());
  269. return result;
  270. }
  271. bool IsSubGraph(const AnfNodePtr &node) {
  272. MS_EXCEPTION_IF_NULL(node);
  273. if (node->isa<CNode>()) {
  274. auto cnode = node->cast<CNodePtr>();
  275. auto &inputs = cnode->inputs();
  276. if (inputs.empty()) {
  277. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  278. }
  279. AnfNodePtr fn = inputs[0];
  280. if (!IsValueNode<Primitive>(fn)) {
  281. return false;
  282. }
  283. auto node_prim = GetValueNode<PrimitivePtr>(fn);
  284. if (node_prim->name() == prim::kPrimPartial->name()) {
  285. return true;
  286. }
  287. } else if (IsValueNode<FuncGraph>(node)) {
  288. return true;
  289. }
  290. return false;
  291. }
  292. } // namespace
  293. CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
  294. : backend_(backend), cut_list_(cut_list) {
  295. MS_EXCEPTION_IF_NULL(backend_);
  296. lin_convert_ = backend_->convert_fn();
  297. if (lin_convert_ == nullptr) {
  298. MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
  299. }
  300. is_gevm_convert_ = false;
  301. if (backend->name() == kGeVm) {
  302. MS_LOG(INFO) << "Attribute 'is_gevm_convert' is true";
  303. is_gevm_convert_ = true;
  304. }
  305. }
  306. bool CompileGraph::IsCut(const AnfNodePtr &node) {
  307. MS_EXCEPTION_IF_NULL(node);
  308. if (node->isa<CNode>()) {
  309. auto cnode = node->cast<CNodePtr>();
  310. auto &inputs = cnode->inputs();
  311. if (inputs.empty()) {
  312. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  313. }
  314. AnfNodePtr fn = inputs[0];
  315. if (IsValueNode<FuncGraph>(fn)) {
  316. auto fg = GetValueNode<FuncGraphPtr>(fn);
  317. if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  318. return false;
  319. }
  320. }
  321. if (!IsValueNode<Primitive>(fn)) {
  322. return true;
  323. }
  324. PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn);
  325. for (auto &prim : cut_list_) {
  326. MS_EXCEPTION_IF_NULL(prim);
  327. if (prim->name() == node_prim->name()) {
  328. if (prim->name() == prim::kPrimBpropCut->name()) {
  329. auto ms_context = MsContext::GetInstance();
  330. MS_EXCEPTION_IF_NULL(ms_context);
  331. ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
  332. }
  333. if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
  334. if (inputs.size() < 2) {
  335. return false;
  336. }
  337. auto ret = IsSubGraph(inputs[1]);
  338. return ret;
  339. }
  340. return true;
  341. }
  342. }
  343. #ifdef ENABLE_GE
  344. if (is_gevm_convert_) {
  345. auto name = GetCNodeFuncName(cnode);
  346. auto adpt = transform::DfGraphConvertor::FindAdapter(name);
  347. if (adpt == nullptr) {
  348. return true;
  349. }
  350. }
  351. #endif
  352. }
  353. return false;
  354. }
  355. VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) {
  356. MS_EXCEPTION_IF_NULL(graph);
  357. auto nodes = OptimizeGetItemOrder(input_nodes);
  358. VectorRef splits;
  359. VectorRef split;
  360. std::string last_target;
  361. for (auto &node : nodes) {
  362. MS_EXCEPTION_IF_NULL(node);
  363. if (IsCut(node)) {
  364. if (split.size() != 0) {
  365. splits.push_back(split);
  366. }
  367. splits.push_back(node);
  368. split.clear();
  369. } else if (node->isa<CNode>()) {
  370. std::string cur_target = GetCNodeTarget(node);
  371. if (cur_target != last_target && !last_target.empty() && split.size() != 0) {
  372. splits.push_back(split);
  373. split.clear();
  374. }
  375. last_target = cur_target;
  376. split.push_back(node);
  377. }
  378. }
  379. return splits;
  380. }
  381. VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
  382. MS_EXCEPTION_IF_NULL(graph);
  383. auto nodes = TopoSort(graph->get_return());
  384. MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
  385. if (ContainMultiTarget(nodes)) {
  386. auto context_ptr = MsContext::GetInstance();
  387. MS_EXCEPTION_IF_NULL(context_ptr);
  388. std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  389. nodes = SplitSort(graph, default_target);
  390. return SplitNodesWithTarget(nodes, graph);
  391. }
  392. VectorRef splits;
  393. VectorRef split;
  394. for (auto &node : nodes) {
  395. MS_EXCEPTION_IF_NULL(node);
  396. if (IsCut(node)) {
  397. if (split.size() != 0) {
  398. splits.push_back(split);
  399. }
  400. splits.push_back(node);
  401. split.clear();
  402. } else if (node->isa<CNode>()) {
  403. split.push_back(node);
  404. }
  405. }
  406. return splits;
  407. }
  408. // Push the value node on the stack.
  409. void CompileGraph::Push(const AnfNodePtr &node) {
  410. MS_EXCEPTION_IF_NULL(node);
  411. if (slots_.count(node) > 0) {
  412. MS_LOG(WARNING) << "Push failed node in slots:" << node->DebugString()
  413. << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  414. return;
  415. }
  416. MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
  417. << " is parameter: " << node->isa<Parameter>();
  418. slots_[node] = height_;
  419. set_height(height_ + 1);
  420. }
  421. void CompileGraph::AddInst(const Instruction &inst, const int &arg) {
  422. VectorRef args;
  423. args.push_back(arg);
  424. AddInst(inst, args);
  425. }
  426. void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) {
  427. VectorRef args;
  428. args.push_back(arg);
  429. AddInst(inst, args);
  430. }
  431. void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) {
  432. inst_.push_back(std::make_pair(inst, args));
  433. }
  434. // Gets the stack reference for the node value. If the node is a constant,
  435. // it may actually cause the push in to not be mentioned before.
  436. int CompileGraph::Ref(const AnfNodePtr &node) {
  437. MS_EXCEPTION_IF_NULL(node);
  438. MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
  439. if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
  440. if (IsValueNode<FuncGraph>(node)) {
  441. MS_LOG(DEBUG) << "Push graph.";
  442. AddInst(Instruction::kGraph, GetValueNode(node));
  443. } else {
  444. MS_LOG(DEBUG) << "Push.";
  445. if (IsValueNode<Primitive>(node)) {
  446. MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  447. } else {
  448. AddInst(Instruction::kPush, GetValueNode(node));
  449. }
  450. }
  451. Push(node);
  452. }
  453. MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
  454. << ", return: " << slots_[node] - height_;
  455. return slots_[node] - height_;
  456. }
  457. // Make sure the value of node is at the top of the stack.
  458. void CompileGraph::AddInput(const AnfNodePtr &node) {
  459. MS_EXCEPTION_IF_NULL(node);
  460. if (slots_.count(node) == 0) {
  461. MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
  462. (void)Ref(node);
  463. return;
  464. }
  465. AddInst(Instruction::kInput, Ref(node));
  466. set_height(height_ + 1);
  467. }
  468. // Call back effect in stack
  469. void CompileGraph::Ret(int nargs) { set_height(height_ - nargs); }
  470. void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
  471. MS_EXCEPTION_IF_NULL(graph);
  472. std::vector<AnfNodePtr> parameters = graph->parameters();
  473. for (size_t i = parameters.size(); i != 0; i--) {
  474. Push(parameters[i - 1]);
  475. MS_LOG(DEBUG) << "Push parameter " << i - 1 << ": " << parameters[i - 1]->DebugString(true);
  476. }
  477. }
  478. int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, const std::string &target) {
  479. MS_LOG(DEBUG) << "LinConvert start";
  480. LinConvertResult result;
  481. result = lin_convert_(node_list, target);
  482. if (result.run == nullptr) {
  483. MS_LOG(ERROR) << "LinConvert failed";
  484. return RET_FAILED;
  485. }
  486. if (!(*result.run)) {
  487. if (result.inputs.size() != result.outputs.size()) {
  488. MS_EXCEPTION_IF_NULL(graph);
  489. MS_LOG(EXCEPTION) << "must inputs equal outputs NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  490. } else {
  491. size_t size = result.inputs.size();
  492. for (size_t i = 0; i < size; i++) {
  493. Tie(result.inputs[i], result.outputs[i]);
  494. }
  495. return RET_CONTINUE;
  496. }
  497. }
  498. AddExternal(result);
  499. for (auto &o : result.outputs) {
  500. Push(o);
  501. }
  502. return RET_SUCCESS;
  503. }
  504. int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
  505. MS_EXCEPTION_IF_NULL(node);
  506. MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
  507. std::vector<AnfNodePtr> node_inputs = node->inputs();
  508. if (node_inputs.empty()) {
  509. MS_LOG(EXCEPTION) << "The node->inputs() is empty";
  510. }
  511. AnfNodePtr fn = node_inputs[0];
  512. if (IsValueNode<Primitive>(fn)) {
  513. PrimitivePtr value = GetValueNode<PrimitivePtr>(fn);
  514. MS_LOG(DEBUG) << "The fn is primitive " << (*value).name();
  515. for (size_t i = node_inputs.size() - 1; i > 0; i--) {
  516. AddInput(node->input(i));
  517. }
  518. if (IsPrimitive(fn, prim::kPrimReturn)) {
  519. AddReturn(node);
  520. return RET_BREAK;
  521. }
  522. if (IsPrimitive(fn, prim::kPrimPartial)) {
  523. AddPartial(node);
  524. } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
  525. AddSwitch(node);
  526. } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
  527. AddSwitchLayer(node);
  528. } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
  529. AddMakeTuple(node);
  530. } else {
  531. AddPrimitive(node, value);
  532. }
  533. } else {
  534. int ret = AddCall(graph, node);
  535. if (ret == RET_BREAK) {
  536. return ret;
  537. }
  538. }
  539. Push(node);
  540. return RET_SUCCESS;
  541. }
  542. bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
  543. MS_LOG(DEBUG) << "Start split graph";
  544. MS_EXCEPTION_IF_NULL(graph);
  545. VectorRef splits = SplitNodes(graph);
  546. MS_LOG(DEBUG) << "Split nodes size:" << splits.size();
  547. for (auto &split : splits) {
  548. int ret = RET_SUCCESS;
  549. if (utils::isa<VectorRef>(split)) {
  550. MS_LOG(DEBUG) << "Start a extern LinConvert";
  551. std::vector<AnfNodePtr> args;
  552. auto vec_ref = utils::cast<VectorRef>(split);
  553. (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args),
  554. [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); });
  555. if (args.size() > 0) {
  556. std::string cur_target = GetCNodeTarget(args[0]);
  557. ret = LinConvert(graph, args, cur_target);
  558. } else {
  559. ret = LinConvert(graph, args);
  560. }
  561. MS_LOG(DEBUG) << "End a extern LinConvert";
  562. if (ret == RET_FAILED) {
  563. return false;
  564. }
  565. if (ret == RET_CONTINUE) {
  566. continue;
  567. }
  568. } else {
  569. MS_LOG(DEBUG) << "Start a cut node";
  570. if (!(utils::isa<AnfNodePtr>(split) && utils::cast<AnfNodePtr>(split)->isa<CNode>())) {
  571. MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  572. }
  573. CNodePtr node = utils::cast<AnfNodePtr>(split)->cast<CNodePtr>();
  574. ret = InterpretNode(graph, node);
  575. MS_LOG(DEBUG) << "End a cut node";
  576. if (ret == RET_BREAK) {
  577. break;
  578. }
  579. }
  580. }
  581. MS_LOG(DEBUG) << "End split graph";
  582. return true;
  583. }
  584. InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
  585. MS_EXCEPTION_IF_NULL(graph);
  586. Reset();
  587. PushParameters(graph);
  588. int param_height = height_;
  589. MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
  590. if (!SplitGraph(graph)) {
  591. return inst_;
  592. }
  593. AddPadStack(param_height);
  594. auto ret = inst_;
  595. Reset();
  596. return ret;
  597. }
  598. void CompileGraph::AddPadStack(int param_height) {
  599. int stack_sizes = max_height_ - param_height;
  600. MS_LOG(DEBUG) << "Pad stack max_height_:" << max_height_ << " param:" << param_height
  601. << " need_stack:" << stack_sizes;
  602. if (stack_sizes > 0) {
  603. VectorRef need_stacks({stack_sizes});
  604. (void)inst_.insert(inst_.begin(), std::make_pair(Instruction::kPadStack, need_stacks));
  605. }
  606. }
  607. void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
  608. VectorRef args;
  609. args.emplace_back(Ref(fn));
  610. args.emplace_back(height_);
  611. args.emplace_back(static_cast<int>(size - 1));
  612. MS_LOG(DEBUG) << "Tail call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
  613. AddInst(Instruction::kTailCall, args);
  614. }
  615. void CompileGraph::AddPartial(const CNodePtr &node) {
  616. auto inputs = node->inputs();
  617. VectorRef args;
  618. auto fn = inputs[1];
  619. if (!IsValueNode<FuncGraph>(fn)) {
  620. MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
  621. }
  622. for (size_t i = 1; i < inputs.size(); i++) {
  623. args.emplace_back(Ref(inputs[i]));
  624. }
  625. AddInst(Instruction::kPartial, args);
  626. }
  627. void CompileGraph::AddMakeTuple(const CNodePtr &node) {
  628. auto inputs = node->inputs();
  629. VectorRef args;
  630. for (size_t i = 1; i < inputs.size(); i++) {
  631. args.emplace_back(Ref(inputs[i]));
  632. }
  633. AddInst(Instruction::kTuple, args);
  634. }
  635. void CompileGraph::AddSwitch(const CNodePtr &node) {
  636. auto inputs = node->inputs();
  637. if (inputs.size() < 4) {
  638. MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
  639. }
  640. VectorRef args;
  641. args.emplace_back(Ref(inputs[1]));
  642. args.emplace_back(Ref(inputs[2]));
  643. args.emplace_back(Ref(inputs[3]));
  644. AddInst(Instruction::kSwitch, args);
  645. }
  646. void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
  647. auto inputs = node->inputs();
  648. if (inputs.size() != 3) {
  649. MS_LOG(EXCEPTION) << "Switch layer must have index and branches.";
  650. }
  651. VectorRef args;
  652. args.emplace_back(Ref(inputs[1]));
  653. args.emplace_back(Ref(inputs[2]));
  654. AddInst(Instruction::kSwitchLayer, args);
  655. }
  656. void CompileGraph::AddReturn(const CNodePtr &node) {
  657. VectorRef args;
  658. args.emplace_back(Ref(node->input(1)));
  659. args.emplace_back(height_);
  660. AddInst(Instruction::kReturn, args);
  661. }
  662. void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) {
  663. auto inputs = node->inputs();
  664. VectorRef args;
  665. args.push_back(prim);
  666. for (size_t i = 1; i < inputs.size(); i++) {
  667. args.emplace_back(Ref(inputs[i]));
  668. }
  669. AddInst(Instruction::kPrim, args);
  670. }
  671. int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
  672. auto inputs = node->inputs();
  673. AnfNodePtr fn = inputs[0];
  674. (void)Ref(fn);
  675. size_t size = inputs.size();
  676. for (size_t i = size - 1; i > 0; i--) {
  677. AddInput(inputs[i]);
  678. }
  679. if (node == graph->output()) {
  680. AddTailCall(fn, size);
  681. return RET_BREAK;
  682. }
  683. MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
  684. AddInst(Instruction::kCall, Ref(fn));
  685. Ret(static_cast<int>(size - 1));
  686. return RET_SUCCESS;
  687. }
  688. void CompileGraph::AddExternal(const LinConvertResult &result) {
  689. VectorRef args;
  690. args.push_back(result.run);
  691. args.push_back(result.simu_run);
  692. size_t size = result.inputs.size();
  693. for (size_t i = 0; i < size; i++) {
  694. args.emplace_back(Ref(result.inputs[i]));
  695. }
  696. AddInst(Instruction::kExternal, args);
  697. }
  698. void TraverseGraphMap(
  699. const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphSet &fgs,
  700. const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
  701. MS_EXCEPTION_IF_NULL(manager_ptr);
  702. MS_EXCEPTION_IF_NULL(tr);
  703. for (const auto &fg : fgs) {
  704. for (const auto &ct_any : fg->value_nodes()) {
  705. AnfNodePtr const_primitive_node = ct_any.first;
  706. if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
  707. auto users = manager_ptr->node_users()[const_primitive_node];
  708. for (auto &use : users) {
  709. CNodePtr node = use.first->cast<CNodePtr>();
  710. MS_EXCEPTION_IF_NULL(node);
  711. if (node->func_graph() != fg) {
  712. continue;
  713. }
  714. int key = use.second;
  715. if (key != 0) {
  716. MS_EXCEPTION_IF_NULL(node->input(0));
  717. bool key_is_const = node->input(0)->isa<ValueNode>();
  718. PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
  719. if (value != nullptr) {
  720. bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
  721. bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
  722. if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
  723. continue;
  724. }
  725. }
  726. FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
  727. dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
  728. tr->SetEdge(node, key, NewValueNode(g));
  729. }
  730. }
  731. }
  732. }
  733. }
  734. }
  735. FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
  736. MS_EXCEPTION_IF_NULL(graph);
  737. FuncGraphManagerPtr manager_ptr = graph->manager();
  738. MS_EXCEPTION_IF_NULL(manager_ptr);
  739. MapPrimTypeFuncGraph prim_graphs;
  740. auto get_prim_graph = [&](const PrimitivePtr &prim, const AbstractFunctionPtr &type) {
  741. PrimTypePair prim_type = std::make_pair(prim, type);
  742. if (prim_graphs.end() == prim_graphs.find(prim_type)) {
  743. FuncGraphPtr g = std::make_shared<FuncGraph>();
  744. std::vector<AnfNodePtr> args;
  745. ValueNodePtr prim_ct = NewValueNode(prim);
  746. MS_EXCEPTION_IF_NULL(prim_ct);
  747. prim_ct->set_abstract(type);
  748. args.push_back(prim_ct);
  749. MS_EXCEPTION_IF_NULL(type);
  750. TypedPrimitiveAbstractClosurePtr tp = dyn_cast<abstract::TypedPrimitiveAbstractClosure>(type->GetUnique());
  751. MS_EXCEPTION_IF_NULL(tp);
  752. MS_EXCEPTION_IF_NULL(g);
  753. for (auto t : tp->args_spec_list()) {
  754. ParameterPtr p = g->add_parameter();
  755. p->set_abstract(t);
  756. args.push_back(p);
  757. }
  758. AnfNodePtr out = g->NewCNode(args);
  759. out->set_abstract(tp->output());
  760. g->set_output(out);
  761. prim_graphs[prim_type] = g;
  762. }
  763. return prim_graphs[prim_type];
  764. };
  765. FuncGraphTransaction tr = manager_ptr->Transact();
  766. auto &fgs = manager_ptr->func_graphs();
  767. TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
  768. tr.Commit();
  769. return graph;
  770. }
  771. CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
  772. MS_EXCEPTION_IF_NULL(backend);
  773. MS_LOG(DEBUG) << "Start vm: " << backend->name();
  774. transform_ = std::make_shared<CompileGraph>(backend, cut_list);
  775. Reset();
  776. }
  777. // Convert graphs to unlinked instructions.
  778. void CompileGraphs::Compile(const FuncGraphPtr &graph) {
  779. MS_LOG(DEBUG) << "Start";
  780. mapping_[graph] = static_cast<int>(insts_.size());
  781. if (transform_ != nullptr) {
  782. InstSet insts = transform_->Run(graph);
  783. if (!insts.empty()) {
  784. (void)insts_.insert(insts_.end(), insts.begin(), insts.end());
  785. }
  786. }
  787. MS_LOG(DEBUG) << "End";
  788. }
  789. // Link instructions from multiple function graphs together.
  790. FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) {
  791. MS_LOG(DEBUG) << "Start";
  792. for (std::size_t i = 0; i < insts_.size(); i++) {
  793. InstType inst = insts_[i];
  794. MS_LOG(DEBUG) << "Link point:" << inst_str[inst.first];
  795. if (Instruction::kGraph == inst.first) {
  796. if (inst.second.empty()) {
  797. MS_LOG(EXCEPTION) << "The second element of inst is empty";
  798. }
  799. FuncGraphPtr func_graph = utils::cast<ValuePtr>(inst.second[0])->cast<FuncGraphPtr>();
  800. MS_LOG(DEBUG) << "Link graph:" << func_graph->ToString();
  801. insts_[i] = std::make_pair(Instruction::kPush, VectorRef(std::vector<BaseRef>{mapping_[func_graph]}));
  802. }
  803. }
  804. FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
  805. MS_LOG(DEBUG) << "End";
  806. return rt;
  807. }
  808. // Convert all graphs to unlinked instructions and link them.
  809. FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
  810. MS_EXCEPTION_IF_NULL(graph);
  811. MS_LOG(DEBUG) << "Start";
  812. Reset();
  813. MS_LOG(DEBUG) << "Begin parameter:" << graph->parameters().size();
  814. FuncGraphPtr prim_graph = WrapPrimitives(graph);
  815. Compile(prim_graph);
  816. MS_EXCEPTION_IF_NULL(prim_graph);
  817. FuncGraphSet graphs = prim_graph->manager()->func_graphs();
  818. for (auto g : graphs) {
  819. if (g != graph && g != nullptr && !(g->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
  820. Compile(g);
  821. }
  822. }
  823. FinalVMPtr rt = Link(graph);
  824. Reset();
  825. MS_LOG(DEBUG) << "End";
  826. return rt;
  827. }
  828. bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) {
  829. MS_EXCEPTION_IF_NULL(graph);
  830. auto graph_manager = graph->manager();
  831. MS_EXCEPTION_IF_NULL(graph_manager);
  832. FuncGraphSet graphs = graph_manager->func_graphs();
  833. for (auto &g : graphs) {
  834. auto nodes = TopoSort(g->get_return());
  835. if (ContainMultiTarget(nodes)) {
  836. return true;
  837. }
  838. }
  839. return false;
  840. }
  841. BackendPtr CreateBackend() {
  842. auto context_ptr = MsContext::GetInstance();
  843. MS_EXCEPTION_IF_NULL(context_ptr);
  844. std::string name = context_ptr->backend_policy();
  845. MS_LOG(INFO) << "CreateBackend is: " << name;
  846. if (backend_list.count(name) == 0) {
  847. MS_LOG(EXCEPTION) << "Backend is error: " << name;
  848. }
  849. if (name == kMsConvert) {
  850. std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  851. uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  852. auto backend = std::make_shared<MsBackend>(name, target, device_id);
  853. std::string device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  854. if (device_target == kAscendDevice) {
  855. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
  856. backend->set_is_multi_graph_sink(false);
  857. context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
  858. } else {
  859. backend->set_is_multi_graph_sink(true);
  860. context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
  861. }
  862. }
  863. return backend;
  864. }
  865. return std::make_shared<Backend>(name);
  866. }
  867. } // namespace compile
  868. } // namespace mindspore