You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_partition.cc 26 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "vm/graph_partition.h"
  17. #include <string>
  18. #include <functional>
  19. #include <utility>
  20. #include <map>
  21. #include <queue>
  22. #include <stack>
  23. #include <set>
  24. #include <algorithm>
  25. #include "base/core_ops.h"
  26. #include "utils/utils.h"
  27. #include "utils/ms_context.h"
  28. #include "ps/ps_context.h"
  29. #include "utils/anf_utils.h"
  30. #ifdef ENABLE_GE
  31. #include "transform/graph_ir/convert.h"
  32. #endif
  33. namespace mindspore {
  // Backend name strings. In the visible part of this file, GraphPartition::IsCut compares
  // backend_name_ against kMsConvert and, when ENABLE_GE is defined, against kGeVm.
  34. const char kMsConvert[] = "ms";
  35. const char kMsVm[] = "vm";
  36. const char kGeVm[] = "ge";
  37. namespace compile {
  38. namespace {
  39. std::string GetOtherTarget(const std::vector<AnfNodePtr> &nodes) {
  40. auto context_ptr = MsContext::GetInstance();
  41. MS_EXCEPTION_IF_NULL(context_ptr);
  42. std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  43. for (auto &node : nodes) {
  44. MS_EXCEPTION_IF_NULL(node);
  45. if (!node->isa<CNode>()) {
  46. continue;
  47. }
  48. std::string cur_target = GetCNodeTarget(node);
  49. if (cur_target != default_target) {
  50. return cur_target;
  51. }
  52. }
  53. return "";
  54. }
  55. void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
  56. MS_EXCEPTION_IF_NULL(graph);
  57. MS_EXCEPTION_IF_NULL(nodes_ref);
  58. std::queue<AnfNodePtr> queue;
  59. queue.push(graph->get_return());
  60. std::set<AnfNodePtr> visited;
  61. while (!queue.empty()) {
  62. auto node = queue.front();
  63. queue.pop();
  64. MS_EXCEPTION_IF_NULL(node);
  65. if (!node->isa<CNode>()) {
  66. continue;
  67. }
  68. auto cnode = node->cast<CNodePtr>();
  69. MS_EXCEPTION_IF_NULL(cnode);
  70. for (auto &input : cnode->inputs()) {
  71. auto iter = nodes_ref->find(input);
  72. if (iter != nodes_ref->end()) {
  73. iter->second++;
  74. } else {
  75. (void)nodes_ref->emplace(input, 1UL);
  76. }
  77. if (visited.find(input) != visited.end()) {
  78. continue;
  79. }
  80. visited.insert(input);
  81. queue.push(input);
  82. }
  83. }
  84. }
  85. std::vector<AnfNodePtr> ReorderVirtualNode(const std::vector<AnfNodePtr> &nodes, const PrimitivePtr &reorder_prim) {
  86. std::vector<AnfNodePtr> result;
  87. std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
  88. std::map<AnfNodePtr, size_t> node_positions;
  89. auto add_insert_position = [&insert_positions, &node_positions](const AnfNodePtr &node, const AnfNodePtr &parent) {
  90. if (parent == nullptr) {
  91. return false;
  92. }
  93. auto iter = node_positions.find(parent);
  94. if (iter != node_positions.end()) {
  95. size_t position = iter->second;
  96. auto iter_nodes = insert_positions.find(position);
  97. if (iter_nodes != insert_positions.end()) {
  98. iter_nodes->second.push_back(node);
  99. } else {
  100. (void)insert_positions.emplace(position, std::vector<AnfNodePtr>{node});
  101. }
  102. return true;
  103. }
  104. return false;
  105. };
  106. for (auto &node : nodes) {
  107. MS_EXCEPTION_IF_NULL(node);
  108. if (IsPrimitiveCNode(node, reorder_prim)) {
  109. auto cnode = node->cast<CNodePtr>();
  110. MS_EXCEPTION_IF_NULL(cnode);
  111. auto &inputs = cnode->inputs();
  112. AnfNodePtr parent = nullptr;
  113. const size_t depend_input_size = 2;
  114. if (reorder_prim == prim::kPrimDepend && inputs.size() == depend_input_size + 1 && !inputs[1]->isa<CNode>()) {
  115. parent = inputs[depend_input_size];
  116. } else if (reorder_prim == prim::kPrimTupleGetItem && inputs.size() > 1) {
  117. parent = inputs[1];
  118. }
  119. if (add_insert_position(node, parent)) {
  120. continue;
  121. }
  122. }
  123. result.emplace_back(node);
  124. node_positions[node] = result.size();
  125. }
  126. size_t insert_num = 0;
  127. for (auto &item : insert_positions) {
  128. auto position = SizeToLong(item.first + insert_num);
  129. (void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
  130. insert_num += item.second.size();
  131. }
  132. return result;
  133. }
  134. std::vector<AnfNodePtr> GetNextNodes(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *nodes_ref,
  135. std::vector<AnfNodePtr> *result) {
  136. MS_EXCEPTION_IF_NULL(node);
  137. MS_EXCEPTION_IF_NULL(nodes_ref);
  138. MS_EXCEPTION_IF_NULL(result);
  139. auto cnode = node->cast<CNodePtr>();
  140. MS_EXCEPTION_IF_NULL(cnode);
  141. auto node_inputs = cnode->inputs();
  142. if (!IsPrimitiveCNode(node, prim::kPrimSwitch)) {
  143. std::reverse(node_inputs.begin(), node_inputs.end());
  144. return node_inputs;
  145. }
  146. std::vector<AnfNodePtr> extend_inputs;
  147. for (auto &input : node_inputs) {
  148. MS_EXCEPTION_IF_NULL(input);
  149. if (IsPrimitiveCNode(input, prim::kPrimPartial)) {
  150. auto iter = nodes_ref->find(input);
  151. if (iter != nodes_ref->end() && iter->second == 1) {
  152. iter->second--;
  153. result->emplace_back(input);
  154. auto partial_cnode = input->cast<CNodePtr>();
  155. MS_EXCEPTION_IF_NULL(partial_cnode);
  156. auto partial_inputs = partial_cnode->inputs();
  157. std::reverse(partial_inputs.begin(), partial_inputs.end());
  158. (void)extend_inputs.insert(extend_inputs.end(), partial_inputs.begin(), partial_inputs.end());
  159. continue;
  160. }
  161. }
  162. extend_inputs.emplace_back(input);
  163. }
  164. return extend_inputs;
  165. }
  166. std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
  167. MS_EXCEPTION_IF_NULL(graph);
  168. std::vector<AnfNodePtr> result;
  169. std::stack<AnfNodePtr> to_visit;
  170. std::stack<AnfNodePtr> next_to_visit;
  171. std::map<AnfNodePtr, size_t> nodes_ref;
  172. CalcNodeRefCount(graph, &nodes_ref);
  173. std::string handle_target = default_target;
  174. std::string next_target;
  175. to_visit.push(graph->get_return());
  176. while (!to_visit.empty() || !next_to_visit.empty()) {
  177. if (to_visit.empty()) {
  178. to_visit.swap(next_to_visit);
  179. handle_target = next_target;
  180. }
  181. auto node = to_visit.top();
  182. MS_EXCEPTION_IF_NULL(node);
  183. to_visit.pop();
  184. result.emplace_back(node);
  185. if (!node->isa<CNode>()) {
  186. continue;
  187. }
  188. auto next_nodes = GetNextNodes(node, &nodes_ref, &result);
  189. for (auto &input : next_nodes) {
  190. MS_EXCEPTION_IF_NULL(input);
  191. auto iter = nodes_ref.find(input);
  192. if (iter != nodes_ref.end()) {
  193. iter->second--;
  194. if (iter->second != 0) {
  195. continue;
  196. }
  197. }
  198. if (!input->isa<CNode>()) {
  199. to_visit.push(input);
  200. continue;
  201. }
  202. std::string input_target = GetCNodeTarget(input);
  203. if (input_target == handle_target) {
  204. to_visit.push(input);
  205. } else if (next_to_visit.empty() || input_target == next_target) {
  206. next_to_visit.push(input);
  207. next_target = input_target;
  208. } else {
  209. MS_LOG(EXCEPTION) << "Only support two different target";
  210. }
  211. }
  212. }
  213. std::reverse(result.begin(), result.end());
  214. return result;
  215. }
  // Dependency data gathered by GetNodesDependencyInfo for the parallel sort.
  216. struct GraphNodesDependencyInfo {
  // CNodes that have no CNode among their inputs.
  217. std::stack<AnfNodePtr> independent_nodes_;
  // For each CNode, the number of its input slots that hold a CNode.
  218. std::map<AnfNodePtr, size_t> input_num_;
  // For each CNode used as an input, the nodes that consume it (one entry per use).
  219. std::map<AnfNodePtr, std::vector<AnfNodePtr>> output_edges_;
  220. };
  221. GraphNodesDependencyInfo GetNodesDependencyInfo(const FuncGraphPtr &graph) {
  222. MS_EXCEPTION_IF_NULL(graph);
  223. GraphNodesDependencyInfo info;
  224. std::stack<AnfNodePtr> to_visit;
  225. std::map<AnfNodePtr, size_t> nodes_ref;
  226. CalcNodeRefCount(graph, &nodes_ref);
  227. to_visit.push(graph->get_return());
  228. while (!to_visit.empty()) {
  229. auto node = to_visit.top();
  230. to_visit.pop();
  231. if (node == nullptr || !node->isa<CNode>()) {
  232. continue;
  233. }
  234. auto cnode = node->cast<CNodePtr>();
  235. MS_EXCEPTION_IF_NULL(cnode);
  236. auto node_inputs = cnode->inputs();
  237. bool independent = true;
  238. for (auto &input : node_inputs) {
  239. MS_EXCEPTION_IF_NULL(input);
  240. if (input->isa<CNode>()) {
  241. independent = false;
  242. auto output_edge_iter = info.output_edges_.find(input);
  243. if (output_edge_iter != info.output_edges_.end()) {
  244. auto &edges = output_edge_iter->second;
  245. edges.emplace_back(node);
  246. } else {
  247. info.output_edges_[input] = {node};
  248. }
  249. auto input_num_iter = info.input_num_.find(node);
  250. if (input_num_iter != info.input_num_.end()) {
  251. input_num_iter->second++;
  252. } else {
  253. info.input_num_[node] = 1;
  254. }
  255. }
  256. auto ref_iter = nodes_ref.find(input);
  257. if (ref_iter != nodes_ref.end()) {
  258. ref_iter->second--;
  259. if (ref_iter->second != 0) {
  260. continue;
  261. }
  262. }
  263. to_visit.push(input);
  264. }
  265. if (independent) {
  266. info.independent_nodes_.push(node);
  267. }
  268. }
  269. return info;
  270. }
  // Work queues used by ParallelSort.
  271. struct VisitNodesInfo {
  // Nodes on the default target that are waiting to be emitted.
  272. std::queue<AnfNodePtr> default_target_nodes_;
  // Nodes on the other target that are waiting to be emitted.
  273. std::queue<AnfNodePtr> other_target_nodes_;
  // Maps a node to the independent Cast node that ParallelSort emits right before it.
  274. std::map<AnfNodePtr, AnfNodePtr> seed_cast_next_node_;
  275. };
  276. VisitNodesInfo GetVisitNodesInfo(const GraphNodesDependencyInfo &dependency_info, const std::string &default_target,
  277. const std::string &other_target) {
  278. VisitNodesInfo result;
  279. auto independent_nodes = dependency_info.independent_nodes_;
  280. while (!independent_nodes.empty()) {
  281. auto seed_node = independent_nodes.top();
  282. independent_nodes.pop();
  283. MS_EXCEPTION_IF_NULL(seed_node);
  284. auto node_target = GetCNodeTarget(seed_node);
  285. if (IsPrimitiveCNode(seed_node, prim::kPrimCast)) {
  286. auto output_edges_iter = dependency_info.output_edges_.find(seed_node);
  287. if (output_edges_iter != dependency_info.output_edges_.end() && output_edges_iter->second.size() == 1) {
  288. auto &cast_next_node = output_edges_iter->second[0];
  289. auto input_num_iter = dependency_info.input_num_.find(cast_next_node);
  290. if (input_num_iter == dependency_info.input_num_.end()) {
  291. MS_LOG(EXCEPTION) << "Node input num not found!";
  292. }
  293. if (input_num_iter->second > 1 && node_target == GetCNodeTarget(cast_next_node)) {
  294. result.seed_cast_next_node_[cast_next_node] = seed_node;
  295. continue;
  296. }
  297. }
  298. }
  299. if (node_target == default_target) {
  300. result.default_target_nodes_.push(seed_node);
  301. } else if (node_target == other_target) {
  302. result.other_target_nodes_.push(seed_node);
  303. } else {
  304. MS_LOG(EXCEPTION) << "Only support two difference targets";
  305. }
  306. }
  307. return result;
  308. }
  309. std::string ParallelSortDecideNextHandleTarget(const std::vector<AnfNodePtr> &output_edges,
  310. const std::string &node_target,
  311. std::map<AnfNodePtr, std::string> *node_input_target_map) {
  312. MS_EXCEPTION_IF_NULL(node_input_target_map);
  313. std::string next_target = node_target;
  314. for (auto &dst_node : output_edges) {
  315. auto input_target_iter = node_input_target_map->find(dst_node);
  316. if (input_target_iter != node_input_target_map->end() && input_target_iter->second != node_target) {
  317. next_target = input_target_iter->second;
  318. break;
  319. }
  320. auto dst_node_target = GetCNodeTarget(dst_node);
  321. if (dst_node_target != node_target) {
  322. next_target = dst_node_target;
  323. break;
  324. }
  325. }
  326. for (auto &dst_node : output_edges) {
  327. (*node_input_target_map)[dst_node] = node_target;
  328. }
  329. return next_target;
  330. }
  331. void ParallelSortVisitNodeEdges(const std::vector<AnfNodePtr> &output_edges, GraphNodesDependencyInfo *dependency_info,
  332. VisitNodesInfo *visit_nodes_info, const std::string &default_target) {
  333. MS_EXCEPTION_IF_NULL(dependency_info);
  334. MS_EXCEPTION_IF_NULL(visit_nodes_info);
  335. for (auto &dst_node : output_edges) {
  336. auto dst_node_target = GetCNodeTarget(dst_node);
  337. auto input_num_iter = dependency_info->input_num_.find(dst_node);
  338. if (input_num_iter == dependency_info->input_num_.end()) {
  339. MS_LOG(EXCEPTION) << "Node input num not found!";
  340. }
  341. input_num_iter->second--;
  342. if (input_num_iter->second == 1 &&
  343. visit_nodes_info->seed_cast_next_node_.find(dst_node) != visit_nodes_info->seed_cast_next_node_.end()) {
  344. input_num_iter->second--;
  345. }
  346. if (input_num_iter->second > 0) {
  347. continue;
  348. }
  349. if (dst_node_target == default_target) {
  350. visit_nodes_info->default_target_nodes_.push(dst_node);
  351. } else {
  352. visit_nodes_info->other_target_nodes_.push(dst_node);
  353. }
  354. }
  355. }
  356. std::vector<AnfNodePtr> ParallelSort(const FuncGraphPtr &graph, const std::string &default_target,
  357. const std::string &other_target) {
  358. MS_EXCEPTION_IF_NULL(graph);
  359. auto dependency_info = GetNodesDependencyInfo(graph);
  360. auto visit_nodes_info = GetVisitNodesInfo(dependency_info, default_target, other_target);
  361. std::vector<AnfNodePtr> result;
  362. std::string handle_target;
  363. if (!visit_nodes_info.default_target_nodes_.empty()) {
  364. handle_target = default_target;
  365. } else {
  366. handle_target = other_target;
  367. }
  368. std::map<AnfNodePtr, std::string> node_input_target_map;
  369. while (!visit_nodes_info.default_target_nodes_.empty() || !visit_nodes_info.other_target_nodes_.empty()) {
  370. AnfNodePtr ready_node;
  371. if ((!visit_nodes_info.default_target_nodes_.empty() && handle_target == default_target) ||
  372. visit_nodes_info.other_target_nodes_.empty()) {
  373. ready_node = visit_nodes_info.default_target_nodes_.front();
  374. visit_nodes_info.default_target_nodes_.pop();
  375. handle_target = default_target;
  376. } else {
  377. ready_node = visit_nodes_info.other_target_nodes_.front();
  378. visit_nodes_info.other_target_nodes_.pop();
  379. handle_target = other_target;
  380. }
  381. MS_EXCEPTION_IF_NULL(ready_node);
  382. auto cast_map_iter = visit_nodes_info.seed_cast_next_node_.find(ready_node);
  383. if (cast_map_iter != visit_nodes_info.seed_cast_next_node_.end()) {
  384. result.emplace_back(cast_map_iter->second);
  385. }
  386. result.emplace_back(ready_node);
  387. auto output_edges_iter = dependency_info.output_edges_.find(ready_node);
  388. if (output_edges_iter == dependency_info.output_edges_.end()) {
  389. continue;
  390. }
  391. auto &output_edges = output_edges_iter->second;
  392. handle_target = ParallelSortDecideNextHandleTarget(output_edges, handle_target, &node_input_target_map);
  393. ParallelSortVisitNodeEdges(output_edges, &dependency_info, &visit_nodes_info, default_target);
  394. }
  395. return result;
  396. }
  397. void AddSegmentDependency(const FuncGraphPtr &graph, const std::map<AnfNodePtr, GraphSegmentPtr> &node_to_segment) {
  398. MS_EXCEPTION_IF_NULL(graph);
  399. std::stack<AnfNodePtr> to_visit;
  400. std::map<AnfNodePtr, size_t> nodes_ref;
  401. CalcNodeRefCount(graph, &nodes_ref);
  402. to_visit.push(graph->get_return());
  403. while (!to_visit.empty()) {
  404. auto &node = to_visit.top();
  405. MS_EXCEPTION_IF_NULL(node);
  406. to_visit.pop();
  407. if (!node->isa<CNode>()) {
  408. continue;
  409. }
  410. auto cnode = node->cast<CNodePtr>();
  411. MS_EXCEPTION_IF_NULL(cnode);
  412. auto node_inputs = cnode->inputs();
  413. GraphSegmentPtr node_segment{nullptr};
  414. auto node_iter = node_to_segment.find(node);
  415. if (node_iter != node_to_segment.end()) {
  416. node_segment = node_iter->second;
  417. }
  418. for (auto &input : node_inputs) {
  419. if (node_segment != nullptr && !node_segment->is_cut_ && input != nullptr && input->isa<CNode>()) {
  420. GraphSegmentPtr input_segment{nullptr};
  421. auto input_iter = node_to_segment.find(input);
  422. if (input_iter != node_to_segment.end()) {
  423. input_segment = input_iter->second;
  424. }
  425. if (input_segment != nullptr && input_segment != node_segment && !input_segment->is_cut_) {
  426. node_segment->AddPreSegment(input_segment);
  427. }
  428. }
  429. auto ref_iter = nodes_ref.find(input);
  430. if (ref_iter != nodes_ref.end()) {
  431. ref_iter->second--;
  432. if (ref_iter->second != 0) {
  433. continue;
  434. }
  435. }
  436. to_visit.push(input);
  437. }
  438. }
  439. }
  440. void RemoveUselessDependency(const std::vector<GraphSegmentPtr> *segments) {
  441. MS_EXCEPTION_IF_NULL(segments);
  442. for (auto &segment : *segments) {
  443. MS_EXCEPTION_IF_NULL(segment);
  444. if (segment->is_cut_) {
  445. continue;
  446. }
  447. bool total_virtual_node = true;
  448. for (auto &node : segment->nodes_) {
  449. if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
  450. IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
  451. IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
  452. IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
  453. IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  454. continue;
  455. }
  456. total_virtual_node = false;
  457. break;
  458. }
  459. if (total_virtual_node) {
  460. segment->pre_segments_.clear();
  461. }
  462. }
  463. }
  464. bool IsSubGraph(const AnfNodePtr &node) {
  465. MS_EXCEPTION_IF_NULL(node);
  466. if (node->isa<CNode>()) {
  467. auto cnode = node->cast<CNodePtr>();
  468. auto &inputs = cnode->inputs();
  469. if (inputs.empty()) {
  470. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  471. }
  472. AnfNodePtr fn = inputs[0];
  473. if (!IsValueNode<Primitive>(fn)) {
  474. return false;
  475. }
  476. auto node_prim = GetValueNode<PrimitivePtr>(fn);
  477. if (node_prim->name() == prim::kPrimPartial->name()) {
  478. return true;
  479. }
  480. } else if (IsValueNode<FuncGraph>(node)) {
  481. return true;
  482. }
  483. return false;
  484. }
  485. void AddSegment(const std::vector<AnfNodePtr> &nodes, std::vector<GraphSegmentPtr> *segments,
  486. std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
  487. MS_EXCEPTION_IF_NULL(segments);
  488. MS_EXCEPTION_IF_NULL(node_to_segment);
  489. auto segment = std::make_shared<GraphSegment>(nodes, false);
  490. segments->emplace_back(segment);
  491. for (auto &node : nodes) {
  492. (*node_to_segment)[node] = segment;
  493. }
  494. }
  495. struct SplitDynamicNodesHelper {
  496. void AddNode(const AnfNodePtr &node, bool is_dynamic_shape) {
  497. if (is_dynamic_shape) {
  498. pre_dynamic_nodes.emplace_back(node);
  499. pre_dynamic_nodes_set.insert(node);
  500. } else {
  501. pre_common_nodes.emplace_back(node);
  502. pre_common_nodes_set.insert(node);
  503. }
  504. pre_nodes.emplace_back(node);
  505. }
  506. void AddSegments(std::vector<GraphSegmentPtr> *segments, std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
  507. MS_EXCEPTION_IF_NULL(segments);
  508. MS_EXCEPTION_IF_NULL(node_to_segment);
  509. if (pre_nodes.size() < merge_node_threshold) {
  510. AddSegment(pre_nodes, segments, node_to_segment);
  511. } else {
  512. if (!pre_common_nodes.empty()) {
  513. AddSegment(pre_common_nodes, segments, node_to_segment);
  514. }
  515. if (!pre_dynamic_nodes.empty()) {
  516. AddSegment(pre_dynamic_nodes, segments, node_to_segment);
  517. }
  518. }
  519. pre_common_nodes.clear();
  520. pre_common_nodes_set.clear();
  521. pre_dynamic_nodes.clear();
  522. pre_dynamic_nodes_set.clear();
  523. pre_nodes.clear();
  524. }
  525. std::vector<AnfNodePtr> pre_nodes;
  526. std::vector<AnfNodePtr> pre_dynamic_nodes;
  527. std::vector<AnfNodePtr> pre_common_nodes;
  528. std::set<AnfNodePtr> pre_common_nodes_set;
  529. std::set<AnfNodePtr> pre_dynamic_nodes_set;
  530. size_t merge_node_threshold = 6;
  531. };
  532. void SplitDynamicNodeSegment(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments,
  533. std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment,
  534. const std::set<AnfNodePtr> &dynamic_nodes_set) {
  535. SplitDynamicNodesHelper helper;
  536. for (auto &node : segment_nodes) {
  537. MS_EXCEPTION_IF_NULL(node);
  538. auto cnode = node->cast<CNodePtr>();
  539. MS_EXCEPTION_IF_NULL(cnode);
  540. auto &inputs = cnode->inputs();
  541. bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end();
  542. bool depend_common_node = false;
  543. bool depend_dynamic_node = false;
  544. bool is_last_node_dynamic = false;
  545. for (size_t i = 1; i < inputs.size(); ++i) {
  546. if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) {
  547. has_dynamic_shape = true;
  548. }
  549. if (helper.pre_common_nodes_set.find(inputs[i]) != helper.pre_common_nodes_set.end()) {
  550. depend_common_node = true;
  551. }
  552. if (helper.pre_dynamic_nodes_set.find(inputs[i]) != helper.pre_dynamic_nodes_set.end()) {
  553. depend_dynamic_node = true;
  554. }
  555. }
  556. if (has_dynamic_shape) {
  557. if (depend_common_node) {
  558. helper.AddSegments(segments, node_to_segment);
  559. }
  560. is_last_node_dynamic = true;
  561. } else {
  562. if (depend_dynamic_node) {
  563. helper.AddSegments(segments, node_to_segment);
  564. }
  565. is_last_node_dynamic = false;
  566. }
  567. helper.AddNode(node, is_last_node_dynamic);
  568. }
  569. helper.AddSegments(segments, node_to_segment);
  570. }
  571. void NodesToSegments(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments,
  572. std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
  573. if (segment_nodes.empty()) {
  574. return;
  575. }
  576. auto segment_target = GetCNodeTarget(segment_nodes[0]);
  577. if (segment_target != kAscendDevice) {
  578. AddSegment(segment_nodes, segments, node_to_segment);
  579. return;
  580. }
  581. MS_EXCEPTION_IF_NULL(segments);
  582. MS_EXCEPTION_IF_NULL(node_to_segment);
  583. std::set<AnfNodePtr> dynamic_nodes_set;
  584. for (auto &node : segment_nodes) {
  585. MS_EXCEPTION_IF_NULL(node);
  586. auto cnode = node->cast<CNodePtr>();
  587. if (AnfUtils::IsNodeOutputDynamicShape(cnode)) {
  588. (void)dynamic_nodes_set.insert(node);
  589. }
  590. }
  591. if (dynamic_nodes_set.empty()) {
  592. AddSegment(segment_nodes, segments, node_to_segment);
  593. return;
  594. }
  595. SplitDynamicNodeSegment(segment_nodes, segments, node_to_segment, dynamic_nodes_set);
  596. }
  597. } // namespace
  // Stores the primitives that force a cut (consulted by IsCut) and the backend name.
  598. GraphPartition::GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name)
  599. : cut_list_(cut_list), backend_name_(backend_name) {}
  600. bool GraphPartition::IsCut(const AnfNodePtr &node) {
  601. MS_EXCEPTION_IF_NULL(node);
  602. if (node->isa<CNode>()) {
  603. auto cnode = node->cast<CNodePtr>();
  604. auto &inputs = cnode->inputs();
  605. if (inputs.empty()) {
  606. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  607. }
  608. AnfNodePtr fn = inputs[0];
  609. if (!IsValueNode<Primitive>(fn)) {
  610. return true;
  611. }
  612. auto node_prim = GetValueNode<PrimitivePtr>(fn);
  613. for (auto &prim : cut_list_) {
  614. MS_EXCEPTION_IF_NULL(prim);
  615. if (prim->name() == node_prim->name()) {
  616. if (prim->name() == prim::kPrimBpropCut->name()) {
  617. auto ms_context = MsContext::GetInstance();
  618. MS_EXCEPTION_IF_NULL(ms_context);
  619. ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
  620. }
  621. if (backend_name_ == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
  622. if (inputs.size() <= 1) {
  623. return false;
  624. }
  625. auto ret = IsSubGraph(inputs[1]);
  626. return ret;
  627. }
  628. return true;
  629. }
  630. }
  631. #ifdef ENABLE_GE
  632. if (backend_name_ == kGeVm) {
  633. auto name = GetCNodeFuncName(cnode);
  634. auto adpt = transform::DfGraphConvertor::FindAdapter(name);
  635. if (adpt == nullptr) {
  636. return true;
  637. }
  638. }
  639. #endif
  640. }
  641. return false;
  642. }
  643. std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph, bool *multi_target) {
  644. MS_EXCEPTION_IF_NULL(graph);
  645. auto nodes = TopoSort(graph->get_return());
  646. MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
  647. bool contain_multi_target = ContainMultiTarget(nodes);
  648. if (multi_target != nullptr) {
  649. *multi_target = contain_multi_target;
  650. }
  651. auto context_ptr = MsContext::GetInstance();
  652. MS_EXCEPTION_IF_NULL(context_ptr);
  653. auto enable_loop_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
  654. std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  655. if (contain_multi_target || !enable_loop_sink) {
  656. if (context_ptr->get_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT)) {
  657. auto other_target = GetOtherTarget(nodes);
  658. nodes = ParallelSort(graph, default_target, other_target);
  659. } else {
  660. nodes = SplitSort(graph, default_target);
  661. }
  662. nodes = ReorderVirtualNode(nodes, prim::kPrimTupleGetItem);
  663. nodes = ReorderVirtualNode(nodes, prim::kPrimDepend);
  664. }
  665. std::vector<GraphSegmentPtr> segments;
  666. std::vector<AnfNodePtr> segment_nodes;
  667. std::map<AnfNodePtr, GraphSegmentPtr> node_to_segment;
  668. std::string last_target;
  669. for (auto &node : nodes) {
  670. MS_EXCEPTION_IF_NULL(node);
  671. if (IsCut(node)) {
  672. NodesToSegments(segment_nodes, &segments, &node_to_segment);
  673. segment_nodes.clear();
  674. segment_nodes.emplace_back(node);
  675. auto segment = std::make_shared<GraphSegment>(segment_nodes, true);
  676. segments.push_back(segment);
  677. segment_nodes.clear();
  678. } else if (node->isa<CNode>()) {
  679. if (contain_multi_target) {
  680. std::string cur_target = GetCNodeTarget(node);
  681. if (cur_target != last_target && !last_target.empty()) {
  682. NodesToSegments(segment_nodes, &segments, &node_to_segment);
  683. segment_nodes.clear();
  684. }
  685. last_target = cur_target;
  686. }
  687. segment_nodes.emplace_back(node);
  688. }
  689. }
  690. MS_LOG(DEBUG) << "Segment size:" << segments.size();
  691. if (contain_multi_target) {
  692. AddSegmentDependency(graph, node_to_segment);
  693. RemoveUselessDependency(&segments);
  694. }
  695. return segments;
  696. }
  697. } // namespace compile
  698. } // namespace mindspore