You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_partition.cc 29 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "vm/graph_partition.h"
  17. #include <string>
  18. #include <functional>
  19. #include <utility>
  20. #include <map>
  21. #include <queue>
  22. #include <stack>
  23. #include <set>
  24. #include <algorithm>
  25. #include "base/core_ops.h"
  26. #include "utils/utils.h"
  27. #include "utils/ms_context.h"
  28. #include "ps/ps_context.h"
  29. #ifdef ENABLE_GE
  30. #include "transform/graph_ir/convert.h"
  31. #endif
  32. namespace mindspore {
  // Backend name strings. kMsConvert and kGeVm are compared against GraphPartition::backend_name_
  // inside GraphPartition::IsCut; kMsVm is not referenced by the code visible in this file.
  33. const char kMsConvert[] = "ms";
  34. const char kMsVm[] = "vm";
  35. const char kGeVm[] = "ge";
  36. namespace compile {
  37. namespace {
  38. std::string GetOtherTarget(const std::vector<AnfNodePtr> &nodes) {
  39. auto context_ptr = MsContext::GetInstance();
  40. MS_EXCEPTION_IF_NULL(context_ptr);
  41. std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  42. for (auto &node : nodes) {
  43. if (!node->isa<CNode>()) {
  44. continue;
  45. }
  46. std::string cur_target = GetCNodeTarget(node);
  47. if (cur_target != default_target) {
  48. return cur_target;
  49. }
  50. }
  51. return "";
  52. }
  53. bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node,
  54. std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) {
  55. MS_EXCEPTION_IF_NULL(prior_node);
  56. MS_EXCEPTION_IF_NULL(behind_node);
  57. MS_EXCEPTION_IF_NULL(graph);
  58. auto manager = graph->manager();
  59. MS_EXCEPTION_IF_NULL(manager);
  60. auto &node_users = manager->node_users();
  61. if (prior_node->isa<Parameter>()) {
  62. for (auto &user : node_users[prior_node]) {
  63. auto cnode = user.first->cast<CNodePtr>();
  64. MS_EXCEPTION_IF_NULL(cnode);
  65. if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
  66. prior_nodes->emplace_back(cnode);
  67. }
  68. }
  69. } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) {
  70. prior_nodes->emplace_back(prior_node);
  71. } else {
  72. return false;
  73. }
  74. if (behind_node->isa<Parameter>()) {
  75. for (auto &user : node_users[behind_node]) {
  76. auto cnode = user.first->cast<CNodePtr>();
  77. MS_EXCEPTION_IF_NULL(cnode);
  78. if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
  79. depend_nodes->emplace_back(cnode);
  80. }
  81. }
  82. } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) {
  83. depend_nodes->emplace_back(behind_node);
  84. } else {
  85. return false;
  86. }
  87. return true;
  88. }
  89. void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node,
  90. std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges,
  91. std::map<AnfNodePtr, size_t> *nodes_ref) {
  92. MS_EXCEPTION_IF_NULL(node);
  93. auto input_cnode = node->cast<CNodePtr>();
  94. MS_EXCEPTION_IF_NULL(input_cnode);
  95. auto prior_node = input_cnode->input(kControlDependPriorIndex);
  96. auto depend_node = input_cnode->input(kControlDependBehindIndex);
  97. MS_EXCEPTION_IF_NULL(prior_node);
  98. MS_EXCEPTION_IF_NULL(depend_node);
  99. auto prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0));
  100. MS_EXCEPTION_IF_NULL(prim_ptr);
  101. ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
  102. int64_t depend_mode = 0;
  103. if (mode_ptr != nullptr) {
  104. depend_mode = GetValue<int64_t>(mode_ptr);
  105. }
  106. if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) {
  107. return;
  108. }
  109. std::vector<AnfNodePtr> prior_nodes;
  110. std::vector<AnfNodePtr> behind_nodes;
  111. if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) {
  112. return;
  113. }
  114. for (auto &first_node : prior_nodes) {
  115. for (auto &second_node : behind_nodes) {
  116. MS_EXCEPTION_IF_NULL(first_node);
  117. MS_EXCEPTION_IF_NULL(second_node);
  118. auto iter = control_edges->find(second_node);
  119. if (iter == control_edges->end()) {
  120. (void)control_edges->insert(
  121. std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node}));
  122. } else {
  123. iter->second.emplace_back(first_node);
  124. }
  125. auto ref_iter = nodes_ref->find(first_node);
  126. if (ref_iter != nodes_ref->end()) {
  127. ref_iter->second++;
  128. } else {
  129. (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1));
  130. }
  131. }
  132. }
  133. }
  134. void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref,
  135. std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) {
  136. std::queue<AnfNodePtr> queue;
  137. queue.push(graph->get_return());
  138. std::set<AnfNodePtr> visited;
  139. while (!queue.empty()) {
  140. auto &node = queue.front();
  141. queue.pop();
  142. MS_EXCEPTION_IF_NULL(node);
  143. if (!node->isa<CNode>()) {
  144. continue;
  145. }
  146. auto cnode = node->cast<CNodePtr>();
  147. MS_EXCEPTION_IF_NULL(cnode);
  148. for (auto &input : cnode->inputs()) {
  149. if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) {
  150. AddControlEdge(graph, input, control_edges, nodes_ref);
  151. }
  152. auto iter = nodes_ref->find(input);
  153. if (iter != nodes_ref->end()) {
  154. iter->second++;
  155. } else {
  156. (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1));
  157. }
  158. if (visited.find(input) != visited.end()) {
  159. continue;
  160. }
  161. visited.insert(input);
  162. queue.push(input);
  163. }
  164. }
  165. }
  166. std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
  167. std::vector<AnfNodePtr> result;
  168. std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
  169. std::map<AnfNodePtr, size_t> node_positions;
  170. for (auto &node : nodes) {
  171. if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  172. auto cnode = node->cast<CNodePtr>();
  173. MS_EXCEPTION_IF_NULL(cnode);
  174. auto &inputs = cnode->inputs();
  175. if (inputs.size() < 2) {
  176. MS_LOG(EXCEPTION) << "Invalid get item node";
  177. }
  178. auto &parent = inputs[1];
  179. auto iter = node_positions.find(parent);
  180. if (iter != node_positions.end()) {
  181. size_t position = iter->second;
  182. auto iter_nodes = insert_positions.find(position);
  183. if (iter_nodes != insert_positions.end()) {
  184. iter_nodes->second.push_back(node);
  185. } else {
  186. (void)insert_positions.insert(
  187. std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node}));
  188. }
  189. continue;
  190. }
  191. }
  192. result.emplace_back(node);
  193. node_positions[node] = result.size();
  194. }
  195. size_t insert_num = 0;
  196. for (auto &item : insert_positions) {
  197. size_t position = item.first + insert_num;
  198. (void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
  199. insert_num += item.second.size();
  200. }
  201. return result;
  202. }
  203. std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
  204. std::vector<AnfNodePtr> result;
  205. std::stack<AnfNodePtr> to_visit;
  206. std::stack<AnfNodePtr> next_to_visit;
  207. std::map<AnfNodePtr, size_t> nodes_ref;
  208. std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
  209. CalcNodeRefCount(graph, &nodes_ref, &control_edges);
  210. std::string handle_target = default_target;
  211. std::string next_target;
  212. to_visit.push(graph->get_return());
  213. while (!to_visit.empty() || !next_to_visit.empty()) {
  214. if (to_visit.empty()) {
  215. to_visit.swap(next_to_visit);
  216. handle_target = next_target;
  217. }
  218. auto node = to_visit.top();
  219. MS_EXCEPTION_IF_NULL(node);
  220. to_visit.pop();
  221. result.emplace_back(node);
  222. if (!node->isa<CNode>()) {
  223. continue;
  224. }
  225. auto cnode = node->cast<CNodePtr>();
  226. MS_EXCEPTION_IF_NULL(cnode);
  227. auto node_inputs = cnode->inputs();
  228. if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
  229. std::reverse(node_inputs.begin(), node_inputs.end());
  230. }
  231. auto ctrl_inputs = control_edges.find(node);
  232. if (ctrl_inputs != control_edges.end()) {
  233. node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
  234. }
  235. for (auto &input : node_inputs) {
  236. auto iter = nodes_ref.find(input);
  237. if (iter != nodes_ref.end()) {
  238. iter->second--;
  239. if (iter->second != 0) {
  240. continue;
  241. }
  242. }
  243. if (!input->isa<CNode>()) {
  244. to_visit.push(input);
  245. continue;
  246. }
  247. std::string input_target = GetCNodeTarget(input);
  248. if (input_target == handle_target) {
  249. to_visit.push(input);
  250. } else if (next_to_visit.empty() || input_target == next_target) {
  251. next_to_visit.push(input);
  252. next_target = input_target;
  253. } else {
  254. MS_LOG(EXCEPTION) << "Only support two different target";
  255. }
  256. }
  257. }
  258. std::reverse(result.begin(), result.end());
  259. return result;
  260. }
  // Dependency tables built by GetNodesDependencyInfo and consumed by the parallel sort.
  261. struct GraphNodesDependencyInfo {
  // CNodes that have no CNode among their data inputs and control-edge inputs.
  262. std::stack<AnfNodePtr> independent_nodes_;
  // For each visited CNode: number of its inputs (data and control-edge) that are CNodes.
  263. std::map<AnfNodePtr, size_t> input_num_;
  // For each CNode used as an input: the nodes that consume it, one entry per use.
  264. std::map<AnfNodePtr, std::vector<AnfNodePtr>> output_edges_;
  265. };
  266. GraphNodesDependencyInfo GetNodesDependencyInfo(const FuncGraphPtr &graph) {
  267. MS_EXCEPTION_IF_NULL(graph);
  268. GraphNodesDependencyInfo info;
  269. std::stack<AnfNodePtr> to_visit;
  270. std::map<AnfNodePtr, size_t> nodes_ref;
  271. std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
  272. CalcNodeRefCount(graph, &nodes_ref, &control_edges);
  273. to_visit.push(graph->get_return());
  274. while (!to_visit.empty()) {
  275. auto node = to_visit.top();
  276. to_visit.pop();
  277. if (node == nullptr || !node->isa<CNode>()) {
  278. continue;
  279. }
  280. auto cnode = node->cast<CNodePtr>();
  281. MS_EXCEPTION_IF_NULL(cnode);
  282. auto node_inputs = cnode->inputs();
  283. auto ctrl_inputs = control_edges.find(node);
  284. if (ctrl_inputs != control_edges.end()) {
  285. node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
  286. }
  287. bool independent = true;
  288. for (auto &input : node_inputs) {
  289. if (input->isa<CNode>()) {
  290. independent = false;
  291. auto output_edge_iter = info.output_edges_.find(input);
  292. if (output_edge_iter != info.output_edges_.end()) {
  293. auto &edges = output_edge_iter->second;
  294. edges.emplace_back(node);
  295. } else {
  296. info.output_edges_[input] = {node};
  297. }
  298. auto input_num_iter = info.input_num_.find(node);
  299. if (input_num_iter != info.input_num_.end()) {
  300. input_num_iter->second++;
  301. } else {
  302. info.input_num_[node] = 1;
  303. }
  304. }
  305. auto ref_iter = nodes_ref.find(input);
  306. if (ref_iter != nodes_ref.end()) {
  307. ref_iter->second--;
  308. if (ref_iter->second != 0) {
  309. continue;
  310. }
  311. }
  312. to_visit.push(input);
  313. }
  314. if (independent) {
  315. info.independent_nodes_.push(node);
  316. }
  317. }
  318. return info;
  319. }
  // Work queues used by ParallelSort.
  320. struct VisitNodesInfo {
  // Nodes ready to be emitted whose target equals the default target.
  321. std::queue<AnfNodePtr> default_target_nodes_;
  // Nodes ready to be emitted that belong to the other target.
  322. std::queue<AnfNodePtr> other_target_nodes_;
  // Maps a node to the independent Cast node feeding it; ParallelSort emits that Cast right
  // before the node instead of scheduling it on its own.
  323. std::map<AnfNodePtr, AnfNodePtr> seed_cast_next_node_;
  324. };
  325. VisitNodesInfo GetVisitNodesInfo(const GraphNodesDependencyInfo &dependency_info, const std::string &default_target,
  326. const std::string &other_target) {
  327. VisitNodesInfo result;
  328. auto independent_nodes = dependency_info.independent_nodes_;
  329. while (!independent_nodes.empty()) {
  330. auto seed_node = independent_nodes.top();
  331. independent_nodes.pop();
  332. MS_EXCEPTION_IF_NULL(seed_node);
  333. auto node_target = GetCNodeTarget(seed_node);
  334. if (IsPrimitiveCNode(seed_node, prim::kPrimCast)) {
  335. auto output_edges_iter = dependency_info.output_edges_.find(seed_node);
  336. if (output_edges_iter != dependency_info.output_edges_.end() && output_edges_iter->second.size() == 1) {
  337. auto &cast_next_node = output_edges_iter->second[0];
  338. auto input_num_iter = dependency_info.input_num_.find(cast_next_node);
  339. if (input_num_iter == dependency_info.input_num_.end()) {
  340. MS_LOG(EXCEPTION) << "Node input num not found!";
  341. }
  342. if (input_num_iter->second > 1 && node_target == GetCNodeTarget(cast_next_node)) {
  343. result.seed_cast_next_node_[cast_next_node] = seed_node;
  344. continue;
  345. }
  346. }
  347. }
  348. if (node_target == default_target) {
  349. result.default_target_nodes_.push(seed_node);
  350. } else if (node_target == other_target) {
  351. result.other_target_nodes_.push(seed_node);
  352. } else {
  353. MS_LOG(EXCEPTION) << "Only support two difference targets";
  354. }
  355. }
  356. return result;
  357. }
  358. std::string ParallelSortDecideNextHandleTarget(const std::vector<AnfNodePtr> &output_edges,
  359. const std::string &node_target,
  360. std::map<AnfNodePtr, std::string> *node_input_target_map) {
  361. MS_EXCEPTION_IF_NULL(node_input_target_map);
  362. std::string next_target = node_target;
  363. for (auto &dst_node : output_edges) {
  364. auto input_target_iter = node_input_target_map->find(dst_node);
  365. if (input_target_iter != node_input_target_map->end() && input_target_iter->second != node_target) {
  366. next_target = input_target_iter->second;
  367. break;
  368. }
  369. auto dst_node_target = GetCNodeTarget(dst_node);
  370. if (dst_node_target != node_target) {
  371. next_target = dst_node_target;
  372. break;
  373. }
  374. }
  375. for (auto &dst_node : output_edges) {
  376. (*node_input_target_map)[dst_node] = node_target;
  377. }
  378. return next_target;
  379. }
  380. void ParallelSortVisitNodeEdges(const std::vector<AnfNodePtr> &output_edges, GraphNodesDependencyInfo *dependency_info,
  381. VisitNodesInfo *visit_nodes_info, const std::string &default_target) {
  382. MS_EXCEPTION_IF_NULL(dependency_info);
  383. MS_EXCEPTION_IF_NULL(visit_nodes_info);
  384. for (auto &dst_node : output_edges) {
  385. auto dst_node_target = GetCNodeTarget(dst_node);
  386. auto input_num_iter = dependency_info->input_num_.find(dst_node);
  387. if (input_num_iter == dependency_info->input_num_.end()) {
  388. MS_LOG(EXCEPTION) << "Node input num not found!";
  389. }
  390. input_num_iter->second--;
  391. if (input_num_iter->second == 1 &&
  392. visit_nodes_info->seed_cast_next_node_.find(dst_node) != visit_nodes_info->seed_cast_next_node_.end()) {
  393. input_num_iter->second--;
  394. }
  395. if (input_num_iter->second > 0) {
  396. continue;
  397. }
  398. if (dst_node_target == default_target) {
  399. visit_nodes_info->default_target_nodes_.push(dst_node);
  400. } else {
  401. visit_nodes_info->other_target_nodes_.push(dst_node);
  402. }
  403. }
  404. }
  405. std::vector<AnfNodePtr> ParallelSort(const FuncGraphPtr &graph, const std::string &default_target,
  406. const std::string &other_target) {
  407. MS_EXCEPTION_IF_NULL(graph);
  408. auto dependency_info = GetNodesDependencyInfo(graph);
  409. auto visit_nodes_info = GetVisitNodesInfo(dependency_info, default_target, other_target);
  410. std::vector<AnfNodePtr> result;
  411. std::string handle_target;
  412. if (!visit_nodes_info.default_target_nodes_.empty()) {
  413. handle_target = default_target;
  414. } else {
  415. handle_target = other_target;
  416. }
  417. std::map<AnfNodePtr, std::string> node_input_target_map;
  418. while (!visit_nodes_info.default_target_nodes_.empty() || !visit_nodes_info.other_target_nodes_.empty()) {
  419. AnfNodePtr ready_node;
  420. if ((!visit_nodes_info.default_target_nodes_.empty() && handle_target == default_target) ||
  421. visit_nodes_info.other_target_nodes_.empty()) {
  422. ready_node = visit_nodes_info.default_target_nodes_.front();
  423. visit_nodes_info.default_target_nodes_.pop();
  424. handle_target = default_target;
  425. } else {
  426. ready_node = visit_nodes_info.other_target_nodes_.front();
  427. visit_nodes_info.other_target_nodes_.pop();
  428. handle_target = other_target;
  429. }
  430. MS_EXCEPTION_IF_NULL(ready_node);
  431. auto cast_map_iter = visit_nodes_info.seed_cast_next_node_.find(ready_node);
  432. if (cast_map_iter != visit_nodes_info.seed_cast_next_node_.end()) {
  433. result.emplace_back(cast_map_iter->second);
  434. }
  435. result.emplace_back(ready_node);
  436. auto output_edges_iter = dependency_info.output_edges_.find(ready_node);
  437. if (output_edges_iter == dependency_info.output_edges_.end()) {
  438. continue;
  439. }
  440. auto &output_edges = output_edges_iter->second;
  441. handle_target = ParallelSortDecideNextHandleTarget(output_edges, handle_target, &node_input_target_map);
  442. ParallelSortVisitNodeEdges(output_edges, &dependency_info, &visit_nodes_info, default_target);
  443. }
  444. return result;
  445. }
  446. void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_target,
  447. const std::map<AnfNodePtr, GraphSegmentPtr> &node_to_segment) {
  448. std::stack<AnfNodePtr> to_visit;
  449. std::map<AnfNodePtr, size_t> nodes_ref;
  450. std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
  451. CalcNodeRefCount(graph, &nodes_ref, &control_edges);
  452. to_visit.push(graph->get_return());
  453. while (!to_visit.empty()) {
  454. auto &node = to_visit.top();
  455. MS_EXCEPTION_IF_NULL(node);
  456. to_visit.pop();
  457. if (!node->isa<CNode>()) {
  458. continue;
  459. }
  460. auto cnode = node->cast<CNodePtr>();
  461. MS_EXCEPTION_IF_NULL(cnode);
  462. auto node_inputs = cnode->inputs();
  463. auto ctrl_inputs = control_edges.find(node);
  464. if (ctrl_inputs != control_edges.end()) {
  465. node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
  466. }
  467. GraphSegmentPtr node_segment{nullptr};
  468. if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
  469. auto node_iter = node_to_segment.find(node);
  470. if (node_iter != node_to_segment.end()) {
  471. node_segment = node_iter->second;
  472. }
  473. }
  474. for (auto &input : node_inputs) {
  475. if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) {
  476. GraphSegmentPtr input_segment{nullptr};
  477. auto input_iter = node_to_segment.find(input);
  478. if (input_iter != node_to_segment.end()) {
  479. input_segment = input_iter->second;
  480. }
  481. if (input_segment != nullptr && input_segment != node_segment && !input_segment->is_cut_) {
  482. node_segment->AddPreSegment(input_segment);
  483. }
  484. }
  485. auto ref_iter = nodes_ref.find(input);
  486. if (ref_iter != nodes_ref.end()) {
  487. ref_iter->second--;
  488. if (ref_iter->second != 0) {
  489. continue;
  490. }
  491. }
  492. to_visit.push(input);
  493. }
  494. }
  495. }
  496. bool IsSubGraph(const AnfNodePtr &node) {
  497. MS_EXCEPTION_IF_NULL(node);
  498. if (node->isa<CNode>()) {
  499. auto cnode = node->cast<CNodePtr>();
  500. auto &inputs = cnode->inputs();
  501. if (inputs.empty()) {
  502. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  503. }
  504. AnfNodePtr fn = inputs[0];
  505. if (!IsValueNode<Primitive>(fn)) {
  506. return false;
  507. }
  508. auto node_prim = GetValueNode<PrimitivePtr>(fn);
  509. if (node_prim->name() == prim::kPrimPartial->name()) {
  510. return true;
  511. }
  512. } else if (IsValueNode<FuncGraph>(node)) {
  513. return true;
  514. }
  515. return false;
  516. }
  517. bool IsShapeDynamic(const abstract::ShapePtr &shape) {
  518. MS_EXCEPTION_IF_NULL(shape);
  519. return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
  520. }
  521. bool IsNodeOutputDynamicShape(const CNodePtr &node) {
  522. MS_EXCEPTION_IF_NULL(node);
  523. auto base_shape = node->Shape();
  524. if (base_shape == nullptr) {
  525. MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
  526. return false;
  527. }
  528. if (base_shape->isa<abstract::Shape>()) {
  529. if (IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
  530. return true;
  531. }
  532. } else if (base_shape->isa<abstract::TupleShape>()) {
  533. auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
  534. MS_EXCEPTION_IF_NULL(tuple_shape);
  535. for (size_t i = 0; i < tuple_shape->size(); i++) {
  536. auto b_shape = (*tuple_shape)[i];
  537. if (!b_shape->isa<abstract::Shape>()) {
  538. continue;
  539. }
  540. if (IsShapeDynamic(b_shape->cast<abstract::ShapePtr>())) {
  541. return true;
  542. }
  543. }
  544. }
  545. return false;
  546. }
  547. void AddSegment(const std::vector<AnfNodePtr> &nodes, std::vector<GraphSegmentPtr> *segments,
  548. std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
  549. MS_EXCEPTION_IF_NULL(segments);
  550. MS_EXCEPTION_IF_NULL(node_to_segment);
  551. auto segment = std::make_shared<GraphSegment>(nodes, false);
  552. segments->emplace_back(segment);
  553. for (auto &node : nodes) {
  554. (*node_to_segment)[node] = segment;
  555. }
  556. }
  557. struct SplitDynamicNodesHelper {
  558. void AddNode(const AnfNodePtr &node, bool is_dynamic_shape) {
  559. if (is_dynamic_shape) {
  560. pre_dynamic_nodes.emplace_back(node);
  561. pre_dynamic_nodes_set.insert(node);
  562. } else {
  563. pre_common_nodes.emplace_back(node);
  564. pre_common_nodes_set.insert(node);
  565. }
  566. pre_nodes.emplace_back(node);
  567. }
  568. void AddSegments(std::vector<GraphSegmentPtr> *segments, std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
  569. if (pre_nodes.size() < merge_node_threshold) {
  570. AddSegment(pre_nodes, segments, node_to_segment);
  571. } else {
  572. if (!pre_common_nodes.empty()) {
  573. AddSegment(pre_common_nodes, segments, node_to_segment);
  574. }
  575. if (!pre_dynamic_nodes.empty()) {
  576. AddSegment(pre_dynamic_nodes, segments, node_to_segment);
  577. }
  578. }
  579. pre_common_nodes.clear();
  580. pre_common_nodes_set.clear();
  581. pre_dynamic_nodes.clear();
  582. pre_dynamic_nodes_set.clear();
  583. pre_nodes.clear();
  584. }
  585. std::vector<AnfNodePtr> pre_nodes;
  586. std::vector<AnfNodePtr> pre_dynamic_nodes;
  587. std::vector<AnfNodePtr> pre_common_nodes;
  588. std::set<AnfNodePtr> pre_common_nodes_set;
  589. std::set<AnfNodePtr> pre_dynamic_nodes_set;
  590. size_t merge_node_threshold = 6;
  591. };
  592. void SplitDynamicNodeSegment(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments,
  593. std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment,
  594. const std::set<AnfNodePtr> &dynamic_nodes_set) {
  595. SplitDynamicNodesHelper helper;
  596. bool is_last_node_dynamic = false;
  597. for (auto &node : segment_nodes) {
  598. auto cnode = node->cast<CNodePtr>();
  599. MS_EXCEPTION_IF_NULL(cnode);
  600. if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
  601. helper.AddNode(node, is_last_node_dynamic);
  602. continue;
  603. }
  604. auto &inputs = cnode->inputs();
  605. bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end();
  606. bool depend_common_node = false;
  607. bool depend_dynamic_node = false;
  608. for (size_t i = 1; i < inputs.size(); ++i) {
  609. if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) {
  610. has_dynamic_shape = true;
  611. }
  612. if (helper.pre_common_nodes_set.find(inputs[i]) != helper.pre_common_nodes_set.end()) {
  613. depend_common_node = true;
  614. }
  615. if (helper.pre_dynamic_nodes_set.find(inputs[i]) != helper.pre_dynamic_nodes_set.end()) {
  616. depend_dynamic_node = true;
  617. }
  618. }
  619. if (has_dynamic_shape) {
  620. if (depend_common_node) {
  621. helper.AddSegments(segments, node_to_segment);
  622. }
  623. is_last_node_dynamic = true;
  624. } else {
  625. if (depend_dynamic_node) {
  626. helper.AddSegments(segments, node_to_segment);
  627. }
  628. is_last_node_dynamic = false;
  629. }
  630. helper.AddNode(node, is_last_node_dynamic);
  631. }
  632. helper.AddSegments(segments, node_to_segment);
  633. }
  634. void NodesToSegments(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments,
  635. std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
  636. if (segment_nodes.empty()) {
  637. return;
  638. }
  639. auto segment_target = GetCNodeTarget(segment_nodes[0]);
  640. if (segment_target != kAscendDevice) {
  641. AddSegment(segment_nodes, segments, node_to_segment);
  642. return;
  643. }
  644. MS_EXCEPTION_IF_NULL(segments);
  645. MS_EXCEPTION_IF_NULL(node_to_segment);
  646. std::set<AnfNodePtr> dynamic_nodes_set;
  647. for (auto &node : segment_nodes) {
  648. auto cnode = node->cast<CNodePtr>();
  649. if (IsNodeOutputDynamicShape(cnode)) {
  650. (void)dynamic_nodes_set.insert(node);
  651. }
  652. }
  653. if (dynamic_nodes_set.empty()) {
  654. AddSegment(segment_nodes, segments, node_to_segment);
  655. return;
  656. }
  657. SplitDynamicNodeSegment(segment_nodes, segments, node_to_segment, dynamic_nodes_set);
  658. }
  659. } // namespace
  660. GraphPartition::GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name)
  661. : cut_list_(cut_list), backend_name_(backend_name) {}
  662. bool GraphPartition::IsCut(const AnfNodePtr &node) {
  663. MS_EXCEPTION_IF_NULL(node);
  664. if (node->isa<CNode>()) {
  665. auto cnode = node->cast<CNodePtr>();
  666. auto &inputs = cnode->inputs();
  667. if (inputs.empty()) {
  668. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  669. }
  670. AnfNodePtr fn = inputs[0];
  671. if (IsValueNode<FuncGraph>(fn)) {
  672. auto fg = GetValueNode<FuncGraphPtr>(fn);
  673. if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  674. return false;
  675. }
  676. }
  677. if (!IsValueNode<Primitive>(fn)) {
  678. return true;
  679. }
  680. auto node_prim = GetValueNode<PrimitivePtr>(fn);
  681. for (auto &prim : cut_list_) {
  682. MS_EXCEPTION_IF_NULL(prim);
  683. if (prim->name() == node_prim->name()) {
  684. if (prim->name() == prim::kPrimBpropCut->name()) {
  685. auto ms_context = MsContext::GetInstance();
  686. MS_EXCEPTION_IF_NULL(ms_context);
  687. ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
  688. }
  689. if (backend_name_ == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
  690. if (inputs.size() < 2) {
  691. return false;
  692. }
  693. auto ret = IsSubGraph(inputs[1]);
  694. return ret;
  695. }
  696. return true;
  697. }
  698. }
  699. #ifdef ENABLE_GE
  700. if (backend_name_ == kGeVm) {
  701. auto name = GetCNodeFuncName(cnode);
  702. auto adpt = transform::DfGraphConvertor::FindAdapter(name);
  703. if (adpt == nullptr) {
  704. return true;
  705. }
  706. }
  707. #endif
  708. }
  709. return false;
  710. }
  711. std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph) {
  712. MS_EXCEPTION_IF_NULL(graph);
  713. auto nodes = TopoSort(graph->get_return());
  714. MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
  715. bool contain_multi_target = ContainMultiTarget(nodes);
  716. auto context_ptr = MsContext::GetInstance();
  717. MS_EXCEPTION_IF_NULL(context_ptr);
  718. std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  719. if (contain_multi_target) {
  720. if (context_ptr->get_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT)) {
  721. auto other_target = GetOtherTarget(nodes);
  722. nodes = ParallelSort(graph, default_target, other_target);
  723. } else {
  724. nodes = SplitSort(graph, default_target);
  725. }
  726. nodes = OptimizeGetItemOrder(nodes);
  727. }
  728. std::vector<GraphSegmentPtr> segments;
  729. std::vector<AnfNodePtr> segment_nodes;
  730. std::map<AnfNodePtr, GraphSegmentPtr> node_to_segment;
  731. std::string last_target;
  732. for (auto &node : nodes) {
  733. MS_EXCEPTION_IF_NULL(node);
  734. if (IsCut(node)) {
  735. NodesToSegments(segment_nodes, &segments, &node_to_segment);
  736. segment_nodes.clear();
  737. segment_nodes.emplace_back(node);
  738. auto segment = std::make_shared<GraphSegment>(segment_nodes, true);
  739. segments.push_back(segment);
  740. segment_nodes.clear();
  741. } else if (node->isa<CNode>()) {
  742. if (contain_multi_target) {
  743. std::string cur_target = GetCNodeTarget(node);
  744. if (cur_target != last_target && !last_target.empty()) {
  745. NodesToSegments(segment_nodes, &segments, &node_to_segment);
  746. segment_nodes.clear();
  747. }
  748. last_target = cur_target;
  749. }
  750. segment_nodes.emplace_back(node);
  751. }
  752. }
  753. MS_LOG(DEBUG) << "Segment size:" << segments.size();
  754. if (contain_multi_target) {
  755. AddSegmentDependency(graph, default_target, node_to_segment);
  756. }
  757. return segments;
  758. }
  759. } // namespace compile
  760. } // namespace mindspore