You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transform.cc 29 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/transform.h"
  19. #include <algorithm>
  20. #include <map>
  21. #include <queue>
  22. #include <stack>
  23. #include <set>
  24. #include <string>
  25. #include <vector>
  26. #include "pipeline/static_analysis/abstract_value.h"
  27. #ifdef ENABLE_GE
  28. #include "transform/convert.h"
  29. #endif
  30. #include "utils/graph_utils.h"
  31. #include "utils/context/ms_context.h"
  32. #include "debug/trace.h"
  33. namespace mindspore {
  34. namespace compile {
  35. using mindspore::abstract::AbstractFunction;
  36. using mindspore::abstract::AbstractFunctionPtr;
  37. using PrimTypePair = std::pair<PrimitivePtr, AbstractFunctionPtr>;
  38. using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
  39. using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
  40. std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
  41. prim::kPrimMakeTuple, prim::kPrimBpropCut};
  42. const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
  43. static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
  44. prim::kPrimBpropCut};
  45. return ms_nonlinear_ops;
  46. }
  47. namespace {
  48. std::string GetCNodeTarget(const AnfNodePtr &node) {
  49. auto context_ptr = MsContext::GetInstance();
  50. MS_EXCEPTION_IF_NULL(context_ptr);
  51. std::string default_target = context_ptr->device_target();
  52. if (!node->isa<CNode>()) {
  53. return default_target;
  54. }
  55. auto cnode = node->cast<CNodePtr>();
  56. MS_EXCEPTION_IF_NULL(cnode);
  57. auto attr_input = cnode->input(kAnfPrimitiveIndex);
  58. if (attr_input == nullptr) {
  59. return default_target;
  60. }
  61. auto value_node = attr_input->cast<ValueNodePtr>();
  62. if (value_node == nullptr) {
  63. return default_target;
  64. }
  65. auto value = value_node->value();
  66. if (value == nullptr) {
  67. return default_target;
  68. }
  69. if (!value->isa<Primitive>()) {
  70. return default_target;
  71. }
  72. auto primitive = value->cast<PrimitivePtr>();
  73. auto att_target = primitive->GetAttr("primitive_target");
  74. if (att_target != nullptr) {
  75. if (!att_target->isa<StringImm>()) {
  76. MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
  77. }
  78. auto target = GetValue<std::string>(att_target);
  79. if (kTargetSet.find(target) == kTargetSet.end()) {
  80. MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
  81. }
  82. return target;
  83. }
  84. return default_target;
  85. }
  86. bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
  87. auto context_ptr = MsContext::GetInstance();
  88. MS_EXCEPTION_IF_NULL(context_ptr);
  89. std::string last_target = context_ptr->device_target();
  90. for (auto &node : nodes) {
  91. if (node->isa<CNode>()) {
  92. std::string cur_target = GetCNodeTarget(node);
  93. if (last_target != cur_target) {
  94. return true;
  95. }
  96. last_target = cur_target;
  97. }
  98. }
  99. return false;
  100. }
  101. void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
  102. std::queue<AnfNodePtr> queue;
  103. queue.push(graph->get_return());
  104. std::set<AnfNodePtr> visited;
  105. while (!queue.empty()) {
  106. auto &node = queue.front();
  107. queue.pop();
  108. MS_EXCEPTION_IF_NULL(node);
  109. if (!node->isa<CNode>()) {
  110. continue;
  111. }
  112. auto cnode = node->cast<CNodePtr>();
  113. MS_EXCEPTION_IF_NULL(cnode);
  114. for (auto &input : cnode->inputs()) {
  115. auto iter = nodes_ref->find(input);
  116. if (iter != nodes_ref->end()) {
  117. iter->second++;
  118. } else {
  119. (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1));
  120. }
  121. if (visited.find(input) != visited.end()) {
  122. continue;
  123. }
  124. visited.insert(input);
  125. queue.push(input);
  126. }
  127. }
  128. }
  129. bool IsGetItemNode(const AnfNodePtr &node) {
  130. MS_EXCEPTION_IF_NULL(node);
  131. if (node->isa<CNode>()) {
  132. auto cnode = node->cast<CNodePtr>();
  133. auto &inputs = cnode->inputs();
  134. if (inputs.empty()) {
  135. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  136. }
  137. if (!IsValueNode<Primitive>(inputs[0])) {
  138. return true;
  139. }
  140. PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]);
  141. return node_prim->name() == prim::kPrimTupleGetItem->name();
  142. }
  143. return false;
  144. }
  145. std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) {
  146. std::vector<AnfNodePtr> result;
  147. std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
  148. std::map<AnfNodePtr, size_t> node_positions;
  149. for (auto &node : nodes) {
  150. if (IsGetItemNode(node)) {
  151. auto cnode = node->cast<CNodePtr>();
  152. MS_EXCEPTION_IF_NULL(cnode);
  153. auto &inputs = cnode->inputs();
  154. if (inputs.size() < 2) {
  155. MS_LOG(EXCEPTION) << "Invalid get item node";
  156. }
  157. auto &parent = inputs[1];
  158. auto iter = node_positions.find(parent);
  159. if (iter != node_positions.end()) {
  160. size_t position = iter->second;
  161. auto iter_nodes = insert_positions.find(position);
  162. if (iter_nodes != insert_positions.end()) {
  163. iter_nodes->second.push_back(node);
  164. } else {
  165. (void)insert_positions.insert(
  166. std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node}));
  167. }
  168. continue;
  169. }
  170. }
  171. result.emplace_back(node);
  172. node_positions[node] = result.size();
  173. }
  174. size_t insert_num = 0;
  175. for (auto &item : insert_positions) {
  176. size_t position = item.first + insert_num;
  177. (void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
  178. insert_num += item.second.size();
  179. }
  180. return result;
  181. }
  182. std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
  183. std::vector<AnfNodePtr> result;
  184. std::stack<AnfNodePtr> to_visit;
  185. std::stack<AnfNodePtr> next_to_visit;
  186. std::map<AnfNodePtr, size_t> nodes_ref;
  187. CalcNodeRefCount(graph, &nodes_ref);
  188. std::string handle_target = default_target;
  189. std::string next_target = "";
  190. to_visit.push(graph->get_return());
  191. while (!to_visit.empty() || !next_to_visit.empty()) {
  192. if (to_visit.empty()) {
  193. to_visit.swap(next_to_visit);
  194. handle_target = next_target;
  195. }
  196. auto &node = to_visit.top();
  197. MS_EXCEPTION_IF_NULL(node);
  198. to_visit.pop();
  199. result.emplace_back(node);
  200. if (!node->isa<CNode>()) {
  201. continue;
  202. }
  203. auto cnode = node->cast<CNodePtr>();
  204. MS_EXCEPTION_IF_NULL(cnode);
  205. auto node_inputs = cnode->inputs();
  206. std::reverse(node_inputs.begin(), node_inputs.end());
  207. for (auto &input : node_inputs) {
  208. auto iter = nodes_ref.find(input);
  209. if (iter != nodes_ref.end()) {
  210. iter->second--;
  211. if (iter->second != 0) {
  212. continue;
  213. }
  214. }
  215. if (!input->isa<CNode>()) {
  216. to_visit.push(input);
  217. continue;
  218. }
  219. std::string input_target = GetCNodeTarget(input);
  220. if (input_target == handle_target) {
  221. to_visit.push(input);
  222. } else if (next_to_visit.empty() || input_target == next_target) {
  223. next_to_visit.push(input);
  224. next_target = input_target;
  225. } else {
  226. MS_LOG(EXCEPTION) << "only support two different target";
  227. }
  228. }
  229. }
  230. std::reverse(result.begin(), result.end());
  231. return ReorderGetItemNode(result);
  232. }
  233. } // namespace
  234. CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
  235. : backend_(backend), cut_list_(cut_list) {
  236. MS_EXCEPTION_IF_NULL(backend_);
  237. lin_convert_ = backend_->convert_fn();
  238. if (lin_convert_ == nullptr) {
  239. MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
  240. }
  241. is_gevm_convert_ = false;
  242. if (backend->name() == kGeVm) {
  243. MS_LOG(INFO) << "Attribute 'is_gevm_convert' is true";
  244. is_gevm_convert_ = true;
  245. }
  246. }
  247. bool CompileGraph::IsCut(const AnfNodePtr &node) {
  248. MS_EXCEPTION_IF_NULL(node);
  249. if (node->isa<CNode>()) {
  250. auto cnode = node->cast<CNodePtr>();
  251. auto &inputs = cnode->inputs();
  252. if (inputs.empty()) {
  253. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  254. }
  255. AnfNodePtr fn = inputs[0];
  256. if (!IsValueNode<Primitive>(fn)) {
  257. return true;
  258. }
  259. PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn);
  260. for (auto &prim : cut_list_) {
  261. MS_EXCEPTION_IF_NULL(prim);
  262. if (prim->name() == node_prim->name()) {
  263. if (prim->name() == prim::kPrimBpropCut->name()) {
  264. auto ms_context = MsContext::GetInstance();
  265. MS_EXCEPTION_IF_NULL(ms_context);
  266. ms_context->set_enable_pynative_hook(true);
  267. }
  268. return true;
  269. }
  270. }
  271. #ifdef ENABLE_GE
  272. if (is_gevm_convert_) {
  273. auto name = GetCNodeFuncName(cnode);
  274. auto adpt = transform::DfGraphConvertor::FindAdapter(name);
  275. if (adpt == nullptr) {
  276. return true;
  277. }
  278. }
  279. #endif
  280. }
  281. return false;
  282. }
  283. VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
  284. MS_EXCEPTION_IF_NULL(graph);
  285. VectorRef splits;
  286. VectorRef split;
  287. auto nodes = TopoSort(graph->get_return());
  288. if (ContainMultiTarget(nodes)) {
  289. auto context_ptr = MsContext::GetInstance();
  290. MS_EXCEPTION_IF_NULL(context_ptr);
  291. std::string default_target = context_ptr->device_target();
  292. nodes = SplitSort(graph, default_target);
  293. }
  294. std::string last_target;
  295. MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
  296. for (auto &node : nodes) {
  297. MS_EXCEPTION_IF_NULL(node);
  298. if (IsCut(node)) {
  299. MS_LOG(DEBUG) << "Cut node:" << node->DebugString(10) << ", size:" << split.size();
  300. if (split.size() != 0) {
  301. splits.push_back(split);
  302. }
  303. splits.push_back(node);
  304. split.clear();
  305. } else if (node->isa<CNode>()) {
  306. std::string cur_target = GetCNodeTarget(node);
  307. if (cur_target != last_target && !last_target.empty() && split.size() != 0) {
  308. splits.push_back(split);
  309. split.clear();
  310. }
  311. last_target = cur_target;
  312. split.push_back(node);
  313. MS_LOG(DEBUG) << "Insert node:" << node->DebugString(10) << ", size:" << split.size();
  314. }
  315. }
  316. MS_LOG(DEBUG) << "Split node size :" << splits.size();
  317. return splits;
  318. }
  319. // Push the value node on the stack.
  320. void CompileGraph::Push(const AnfNodePtr &node) {
  321. MS_EXCEPTION_IF_NULL(node);
  322. if (slots_.count(node) > 0) {
  323. MS_LOG(EXCEPTION) << "Push failed node in slots:" << node->DebugString()
  324. << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  325. }
  326. MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
  327. << " is parameter: " << node->isa<Parameter>();
  328. slots_[node] = height_;
  329. set_height(height_ + 1);
  330. }
  331. void CompileGraph::AddInst(const Instruction &inst, const int &arg) {
  332. VectorRef args;
  333. args.push_back(arg);
  334. AddInst(inst, args);
  335. }
  336. void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) {
  337. VectorRef args;
  338. args.push_back(arg);
  339. AddInst(inst, args);
  340. }
  341. void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) {
  342. inst_.push_back(std::make_pair(inst, args));
  343. }
  344. // Gets the stack reference for the node value. If the node is a constant,
  345. // it may actually cause the push in to not be mentioned before.
  346. int CompileGraph::Ref(const AnfNodePtr &node) {
  347. MS_EXCEPTION_IF_NULL(node);
  348. MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
  349. if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
  350. if (IsValueNode<FuncGraph>(node)) {
  351. MS_LOG(DEBUG) << "Push graph.";
  352. AddInst(Instruction::kGraph, GetValueNode(node));
  353. } else {
  354. MS_LOG(DEBUG) << "Push.";
  355. if (IsValueNode<Primitive>(node)) {
  356. MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  357. } else {
  358. AddInst(Instruction::kPush, GetValueNode(node));
  359. }
  360. }
  361. Push(node);
  362. }
  363. MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
  364. << ", return: " << slots_[node] - height_;
  365. return slots_[node] - height_;
  366. }
  367. // Make sure the value of node is at the top of the stack.
  368. void CompileGraph::AddInput(const AnfNodePtr &node) {
  369. MS_EXCEPTION_IF_NULL(node);
  370. if (slots_.count(node) == 0) {
  371. MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
  372. (void)Ref(node);
  373. return;
  374. }
  375. AddInst(Instruction::kInput, Ref(node));
  376. set_height(height_ + 1);
  377. }
  378. // Call back effect in stack
  379. void CompileGraph::Ret(int nargs) { set_height(height_ - nargs); }
  380. void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
  381. MS_EXCEPTION_IF_NULL(graph);
  382. std::vector<AnfNodePtr> parameters = graph->parameters();
  383. for (size_t i = parameters.size(); i != 0; i--) {
  384. Push(parameters[i - 1]);
  385. MS_LOG(DEBUG) << "Push parameter " << i - 1 << ": " << parameters[i - 1]->DebugString(true);
  386. }
  387. }
  388. int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, const std::string &target) {
  389. MS_LOG(DEBUG) << "LinConvert start";
  390. LinConvertResult result;
  391. if (backend_->simu_flag()) {
  392. result = backend_->GetMultiGraphRun(graph);
  393. } else {
  394. result = lin_convert_(node_list, target);
  395. }
  396. if (result.run == nullptr) {
  397. MS_LOG(ERROR) << "LinConvert failed";
  398. return RET_FAILED;
  399. }
  400. if (!(*result.run)) {
  401. if (result.inputs.size() != result.outputs.size()) {
  402. MS_EXCEPTION_IF_NULL(graph);
  403. MS_LOG(EXCEPTION) << "must inputs equal outputs NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  404. } else {
  405. size_t size = result.inputs.size();
  406. for (size_t i = 0; i < size; i++) {
  407. Tie(result.inputs[i], result.outputs[i]);
  408. }
  409. return RET_CONTINUE;
  410. }
  411. }
  412. AddExternal(result);
  413. for (auto &o : result.outputs) {
  414. Push(o);
  415. }
  416. return RET_SUCCESS;
  417. }
  418. void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
  419. MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString();
  420. if (backend_->is_multi_graph_sink()) {
  421. VectorRef args;
  422. args.emplace_back(-1);
  423. MS_LOG(DEBUG) << "call::" << height_;
  424. AddInst(Instruction::kCall, args);
  425. args.clear();
  426. args.emplace_back(node->input(1));
  427. AddInst(Instruction::kSwitchReturn, args);
  428. args.clear();
  429. args.emplace_back(false);
  430. args.emplace_back(Ref(node->input(1)));
  431. args.emplace_back(Ref(node->input(2)));
  432. args.emplace_back(Ref(node->input(3)));
  433. AddInst(Instruction::kSwitch, args);
  434. }
  435. }
  436. int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
  437. MS_EXCEPTION_IF_NULL(node);
  438. MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
  439. std::vector<AnfNodePtr> node_inputs = node->inputs();
  440. if (node_inputs.empty()) {
  441. MS_LOG(EXCEPTION) << "The node->inputs() is empty";
  442. }
  443. AnfNodePtr fn = node_inputs[0];
  444. if (IsValueNode<Primitive>(fn)) {
  445. PrimitivePtr value = GetValueNode<PrimitivePtr>(fn);
  446. MS_LOG(DEBUG) << "The fn is primitive " << (*value).name();
  447. for (size_t i = node_inputs.size() - 1; i > 0; i--) {
  448. AddInput(node->input(i));
  449. }
  450. if (IsPrimitive(fn, prim::kPrimReturn)) {
  451. AddReturn(node);
  452. return RET_BREAK;
  453. }
  454. if (IsPrimitive(fn, prim::kPrimPartial)) {
  455. AddPartial(node);
  456. } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
  457. AddSwitch(node);
  458. AddSinkSwitch(node);
  459. } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
  460. AddMakeTuple(node);
  461. } else {
  462. AddPrimitive(node, value);
  463. }
  464. } else {
  465. int ret = AddCall(graph, node);
  466. if (ret == RET_BREAK) {
  467. return ret;
  468. }
  469. }
  470. Push(node);
  471. return RET_SUCCESS;
  472. }
  473. void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) {
  474. auto ret = LinConvert(graph, {});
  475. if (ret == RET_FAILED) {
  476. MS_LOG(EXCEPTION) << "MultiGraphRun failed.";
  477. }
  478. AddReturn(nullptr);
  479. }
  480. bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
  481. MS_LOG(DEBUG) << "Start split graph";
  482. MS_EXCEPTION_IF_NULL(graph);
  483. VectorRef splits = SplitNodes(graph);
  484. MS_LOG(DEBUG) << "Split nodes size:" << splits.size();
  485. for (auto &split : splits) {
  486. int ret = RET_SUCCESS;
  487. if (utils::isa<VectorRef>(split)) {
  488. MS_LOG(DEBUG) << "Start a extern LinConvert";
  489. std::vector<AnfNodePtr> args;
  490. auto vec_ref = utils::cast<VectorRef>(split);
  491. (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args),
  492. [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); });
  493. if (args.size() > 0) {
  494. std::string cur_target = GetCNodeTarget(args[0]);
  495. ret = LinConvert(graph, args, cur_target);
  496. } else {
  497. ret = LinConvert(graph, args);
  498. }
  499. MS_LOG(DEBUG) << "End a extern LinConvert";
  500. if (ret == RET_FAILED) {
  501. return false;
  502. }
  503. if (ret == RET_CONTINUE) {
  504. continue;
  505. }
  506. } else {
  507. MS_LOG(DEBUG) << "Start a cut node";
  508. if (!(utils::isa<AnfNodePtr>(split) && utils::cast<AnfNodePtr>(split)->isa<CNode>())) {
  509. MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
  510. }
  511. CNodePtr node = utils::cast<AnfNodePtr>(split)->cast<CNodePtr>();
  512. ret = InterpretNode(graph, node);
  513. MS_LOG(DEBUG) << "End a cut node";
  514. if (ret == RET_BREAK) {
  515. break;
  516. }
  517. }
  518. }
  519. MS_LOG(DEBUG) << "End split graph";
  520. return true;
  521. }
  522. InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) {
  523. InstSet inst = Run(graph);
  524. return inst;
  525. }
  526. InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
  527. MS_EXCEPTION_IF_NULL(graph);
  528. MS_LOG(DEBUG) << "Compile start graph: " << graph->ToString();
  529. Reset();
  530. PushParameters(graph);
  531. int param_height = height_;
  532. MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
  533. if (backend_->simu_flag()) {
  534. GenMultiGraphsRun(graph);
  535. } else {
  536. if (!SplitGraph(graph)) {
  537. return inst_;
  538. }
  539. }
  540. AddPadStack(param_height);
  541. auto ret = inst_;
  542. Reset();
  543. return ret;
  544. }
  545. void CompileGraph::AddPadStack(int param_height) {
  546. int stack_sizes = max_height_ - param_height;
  547. MS_LOG(DEBUG) << "Pad stack max_height_:" << max_height_ << " param:" << param_height
  548. << " need_stack:" << stack_sizes;
  549. if (stack_sizes > 0) {
  550. VectorRef need_stacks({stack_sizes});
  551. (void)inst_.insert(inst_.begin(), std::make_pair(Instruction::kPadStack, need_stacks));
  552. }
  553. }
  554. void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
  555. VectorRef args;
  556. args.emplace_back(Ref(fn));
  557. args.emplace_back(height_);
  558. args.emplace_back(static_cast<int>(size - 1));
  559. MS_LOG(DEBUG) << "Tail call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
  560. AddInst(Instruction::kTailCall, args);
  561. }
  562. void CompileGraph::AddPartial(const CNodePtr &node) {
  563. auto inputs = node->inputs();
  564. VectorRef args;
  565. auto fn = inputs[1];
  566. if (!IsValueNode<FuncGraph>(fn)) {
  567. MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
  568. }
  569. if (backend_->is_multi_graph_sink()) {
  570. auto func_graph = GetValueNode<FuncGraphPtr>(fn);
  571. args.emplace_back(func_graph);
  572. AnfNodePtrList outs(inputs.begin() + 2, inputs.end());
  573. backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  574. }
  575. for (size_t i = 1; i < inputs.size(); i++) {
  576. args.emplace_back(Ref(inputs[i]));
  577. }
  578. AddInst(Instruction::kPartial, args);
  579. }
  580. void CompileGraph::AddMakeTuple(const CNodePtr &node) {
  581. auto inputs = node->inputs();
  582. VectorRef args;
  583. for (size_t i = 1; i < inputs.size(); i++) {
  584. args.emplace_back(Ref(inputs[i]));
  585. }
  586. AddInst(Instruction::kTuple, args);
  587. }
  588. void CompileGraph::AddSwitch(const CNodePtr &node) {
  589. auto inputs = node->inputs();
  590. if (inputs.size() < 4) {
  591. MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
  592. }
  593. VectorRef args;
  594. if (backend_->is_multi_graph_sink()) {
  595. args.emplace_back(true);
  596. }
  597. args.emplace_back(Ref(inputs[1]));
  598. args.emplace_back(Ref(inputs[2]));
  599. args.emplace_back(Ref(inputs[3]));
  600. AddInst(Instruction::kSwitch, args);
  601. }
  602. void CompileGraph::AddReturn(const CNodePtr &node) {
  603. VectorRef args;
  604. if (backend_->simu_flag()) {
  605. args.emplace_back(Ref(backend_->final_output()));
  606. } else {
  607. args.emplace_back(Ref(node->input(1)));
  608. }
  609. args.emplace_back(height_);
  610. AddInst(Instruction::kReturn, args);
  611. }
  612. void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) {
  613. auto inputs = node->inputs();
  614. VectorRef args;
  615. args.push_back(prim);
  616. for (size_t i = 1; i < inputs.size(); i++) {
  617. args.emplace_back(Ref(inputs[i]));
  618. }
  619. AddInst(Instruction::kPrim, args);
  620. }
  621. int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
  622. auto inputs = node->inputs();
  623. AnfNodePtr fn = inputs[0];
  624. if (backend_->is_multi_graph_sink() && IsValueNode<FuncGraph>(fn)) {
  625. auto func_graph = GetValueNode<FuncGraphPtr>(fn);
  626. AnfNodePtrList outs(inputs.begin() + 1, inputs.end());
  627. backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  628. }
  629. (void)Ref(fn);
  630. size_t size = inputs.size();
  631. for (size_t i = size - 1; i > 0; i--) {
  632. AddInput(inputs[i]);
  633. }
  634. if (node == graph->output()) {
  635. AddTailCall(fn, size);
  636. return RET_BREAK;
  637. }
  638. MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
  639. AddInst(Instruction::kCall, Ref(fn));
  640. Ret(static_cast<int>(size - 1));
  641. return RET_SUCCESS;
  642. }
  643. void CompileGraph::AddExternal(const LinConvertResult &result) {
  644. VectorRef args;
  645. args.push_back(result.run);
  646. args.push_back(result.simu_run);
  647. size_t size = result.inputs.size();
  648. for (size_t i = 0; i < size; i++) {
  649. args.emplace_back(Ref(result.inputs[i]));
  650. }
  651. AddInst(Instruction::kExternal, args);
  652. }
  653. void TraverseGraphMap(
  654. const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphSet &fgs,
  655. const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
  656. MS_EXCEPTION_IF_NULL(manager_ptr);
  657. MS_EXCEPTION_IF_NULL(tr);
  658. for (const auto &fg : fgs) {
  659. for (const auto &ct_any : fg->value_nodes()) {
  660. AnfNodePtr const_primitive_node = ct_any.first;
  661. if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
  662. auto users = manager_ptr->node_users()[const_primitive_node];
  663. for (auto &use : users) {
  664. CNodePtr node = use.first->cast<CNodePtr>();
  665. MS_EXCEPTION_IF_NULL(node);
  666. if (node->func_graph() != fg) {
  667. continue;
  668. }
  669. int key = use.second;
  670. if (key != 0) {
  671. MS_EXCEPTION_IF_NULL(node->input(0));
  672. bool key_is_const = node->input(0)->isa<ValueNode>();
  673. PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
  674. if (value != nullptr) {
  675. bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
  676. bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
  677. if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
  678. continue;
  679. }
  680. }
  681. FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
  682. dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
  683. tr->SetEdge(node, key, NewValueNode(g));
  684. }
  685. }
  686. }
  687. }
  688. }
  689. }
  690. FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
  691. MS_EXCEPTION_IF_NULL(graph);
  692. FuncGraphManagerPtr manager_ptr = graph->manager();
  693. MS_EXCEPTION_IF_NULL(manager_ptr);
  694. MapPrimTypeFuncGraph prim_graphs;
  695. auto get_prim_graph = [&](const PrimitivePtr &prim, const AbstractFunctionPtr &type) {
  696. PrimTypePair prim_type = std::make_pair(prim, type);
  697. if (prim_graphs.end() == prim_graphs.find(prim_type)) {
  698. FuncGraphPtr g = std::make_shared<FuncGraph>();
  699. std::vector<AnfNodePtr> args;
  700. ValueNodePtr prim_ct = NewValueNode(prim);
  701. MS_EXCEPTION_IF_NULL(prim_ct);
  702. prim_ct->set_abstract(type);
  703. args.push_back(prim_ct);
  704. MS_EXCEPTION_IF_NULL(type);
  705. TypedPrimitiveAbstractClosurePtr tp = dyn_cast<abstract::TypedPrimitiveAbstractClosure>(type->GetUnique());
  706. MS_EXCEPTION_IF_NULL(tp);
  707. MS_EXCEPTION_IF_NULL(g);
  708. for (auto t : tp->args_spec_list()) {
  709. ParameterPtr p = g->add_parameter();
  710. p->set_abstract(t);
  711. args.push_back(p);
  712. }
  713. AnfNodePtr out = g->NewCNode(args);
  714. out->set_abstract(tp->output());
  715. g->set_output(out);
  716. prim_graphs[prim_type] = g;
  717. }
  718. return prim_graphs[prim_type];
  719. };
  720. FuncGraphTransaction tr = manager_ptr->Transact();
  721. auto &fgs = manager_ptr->func_graphs();
  722. TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
  723. tr.Commit();
  724. return graph;
  725. }
  726. CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
  727. MS_EXCEPTION_IF_NULL(backend);
  728. MS_LOG(DEBUG) << "Start vm: " << backend->name();
  729. transform_ = std::make_shared<CompileGraph>(backend, cut_list);
  730. Reset();
  731. }
  732. // Convert graphs to unlinked instructions.
  733. void CompileGraphs::Compile(const FuncGraphPtr &graph) {
  734. MS_LOG(DEBUG) << "Start";
  735. auto graph_manager = graph->manager();
  736. MS_EXCEPTION_IF_NULL(graph_manager);
  737. FuncGraphSet graphs = graph_manager->func_graphs();
  738. for (auto &g : graphs) {
  739. mapping_[g] = static_cast<int>(insts_.size());
  740. if (transform_ != nullptr) {
  741. InstSet insts = transform_->Run(g);
  742. if (!insts.empty()) {
  743. (void)insts_.insert(insts_.end(), insts.begin(), insts.end());
  744. }
  745. }
  746. }
  747. MS_LOG(DEBUG) << "End";
  748. }
  749. // Link instructions from multiple function graphs together.
  750. FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) {
  751. MS_LOG(DEBUG) << "Start";
  752. for (std::size_t i = 0; i < insts_.size(); i++) {
  753. InstType inst = insts_[i];
  754. MS_LOG(DEBUG) << "Link point:" << inst_str[inst.first];
  755. if (Instruction::kGraph == inst.first) {
  756. if (inst.second.empty()) {
  757. MS_LOG(EXCEPTION) << "The second element of inst is empty";
  758. }
  759. FuncGraphPtr func_graph = utils::cast<ValuePtr>(inst.second[0])->cast<FuncGraphPtr>();
  760. MS_LOG(DEBUG) << "Link graph:" << func_graph->ToString();
  761. insts_[i] = std::make_pair(Instruction::kPush, VectorRef(std::vector<BaseRef>{mapping_[func_graph]}));
  762. }
  763. }
  764. FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
  765. if (backend_->is_multi_graph_sink()) {
  766. backend_->set_simu_flag(true);
  767. MS_LOG(DEBUG) << "Start simulate";
  768. backend_->SimulateRun(rt, graph);
  769. MS_LOG(DEBUG) << "Link graphs";
  770. insts_ = transform_->GenMultiGraphsSinkInst(graph);
  771. rt->set_insts(insts_);
  772. backend_->set_simu_flag(false);
  773. MS_LOG(DEBUG) << "End start simulate";
  774. backend_->Link(kInvalidGraphId);
  775. }
  776. MS_LOG(DEBUG) << "End";
  777. return rt;
  778. }
  779. // Convert all graphs to unlinked instructions and link them.
  780. FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
  781. MS_EXCEPTION_IF_NULL(graph);
  782. MS_LOG(DEBUG) << "Start";
  783. Reset();
  784. MS_LOG(DEBUG) << "Begin parameter:" << graph->parameters().size();
  785. (void)WrapPrimitives(graph);
  786. Compile(graph);
  787. FinalVMPtr rt = Link(graph);
  788. Reset();
  789. MS_LOG(DEBUG) << "End";
  790. return rt;
  791. }
  792. bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) {
  793. auto graph_manager = graph->manager();
  794. MS_EXCEPTION_IF_NULL(graph_manager);
  795. FuncGraphSet graphs = graph_manager->func_graphs();
  796. for (auto &g : graphs) {
  797. auto nodes = TopoSort(g->get_return());
  798. if (ContainMultiTarget(nodes)) {
  799. return true;
  800. }
  801. }
  802. return false;
  803. }
  804. BackendPtr CreateBackend() {
  805. auto context_ptr = MsContext::GetInstance();
  806. MS_EXCEPTION_IF_NULL(context_ptr);
  807. std::string name = context_ptr->backend_policy();
  808. MS_LOG(INFO) << "CreateBackend is: " << name;
  809. if (backend_list.count(name) == 0) {
  810. MS_LOG(EXCEPTION) << "Backend is error: " << name;
  811. }
  812. if (name == kMsConvert) {
  813. std::string target = context_ptr->device_target();
  814. uint32_t device_id = context_ptr->device_id();
  815. auto backend = std::make_shared<MsBackend>(name, target, device_id);
  816. std::string device_target = MsContext::GetInstance()->device_target();
  817. if (device_target == kAscendDevice) {
  818. if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
  819. backend->set_is_multi_graph_sink(false);
  820. context_ptr->set_is_multi_graph_sink(false);
  821. } else {
  822. backend->set_is_multi_graph_sink(true);
  823. context_ptr->set_is_multi_graph_sink(true);
  824. }
  825. }
  826. return backend;
  827. }
  828. return std::make_shared<Backend>(name);
  829. }
  830. } // namespace compile
  831. } // namespace mindspore