You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

step_auto_parallel.cc 51 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/parallel/step_auto_parallel.h"
  17. #include <inttypes.h>
  18. #include <sys/time.h>
  19. #include <algorithm>
  20. #include <map>
  21. #include <memory>
  22. #include <set>
  23. #include <string>
  24. #include <unordered_map>
  25. #include <utility>
  26. #include <vector>
  27. #include <unordered_set>
  28. #include "base/core_ops.h"
  29. #include "frontend/optimizer/opt.h"
  30. #include "frontend/optimizer/optimizer.h"
  31. #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
  32. #include "frontend/parallel/auto_parallel/edge_costmodel.h"
  33. #include "frontend/parallel/auto_parallel/graph_costmodel.h"
  34. #include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
  35. #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
  36. #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
  37. #include "frontend/parallel/context.h"
  38. #include "frontend/parallel/graph_util/node_info.h"
  39. #include "frontend/parallel/graph_util/graph_info.h"
  40. #include "frontend/parallel/ops_info/reshape_info.h"
  41. #include "frontend/parallel/ops_info/tmp_identity_info.h"
  42. #include "frontend/parallel/step_parallel.h"
  43. #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
  44. #include "ir/anf.h"
  45. #include "ir/param_info.h"
  46. #include "ir/tensor.h"
  47. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  48. #include "ps/util.h"
  49. #endif
  50. namespace mindspore {
  51. namespace parallel {
  52. bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
  53. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  54. if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
  55. return false;
  56. }
  57. #endif
  58. MS_EXCEPTION_IF_NULL(root);
  59. MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  60. std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  61. // assume no change to graph
  62. bool changes = false;
  63. // control whether use model_parallel mode
  64. if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
  65. root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
  66. return changes;
  67. }
  68. // check whether strategy_search_mode is valid
  69. std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
  70. if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
  71. // Setting searching mode: dynamic programming as default.
  72. strategy_search_mode = DYNAMIC_PROGRAMMING;
  73. MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default";
  74. }
  75. struct timeval start_time, end_time;
  76. (void)gettimeofday(&start_time, nullptr);
  77. if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
  78. draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
  79. }
  80. MS_LOG(INFO) << "Now entering step auto parallel";
  81. TOTAL_OPS = 0;
  82. AnfNodePtr ret = root->get_return();
  83. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  84. if (ParallelInit() != SUCCESS) {
  85. MS_LOG(EXCEPTION) << "Parallel init failed";
  86. }
  87. // mark the forward cnodes, parallel only care these nodes
  88. MarkForwardCNode(root);
  89. if (FindCommunicationOp(all_nodes)) {
  90. MS_LOG(EXCEPTION) << "The graph contain communication op";
  91. }
  92. // search parallelization strategy
  93. if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
  94. if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
  95. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
  96. }
  97. } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
  98. if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
  99. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
  100. }
  101. } else {
  102. MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
  103. }
  104. (void)gettimeofday(&end_time, nullptr);
  105. uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  106. time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  107. MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
  108. root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
  109. return changes;
  110. }
  111. bool IsElementWiseOperator(const std::string &op_name) {
  112. // clang-format off
  113. static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH,
  114. SOFTMAX, LOG_SOFTMAX, RELU,
  115. SQRT, CAST, POW,
  116. EXP, LOG, COS,
  117. ACOS, LOGICALNOT, NEG,
  118. SQUARE, SIGMOID, ABS,
  119. ACOSH, ASIN, ASINH,
  120. ATAN, ATANH, CEIL,
  121. COSH, EXPM1, LOG1P,
  122. SIN, SINH, TAN,
  123. RSQRT, RECIPROCAL, INV,
  124. ROUND, FLOOR, SIGN,
  125. ERF, ERFC, ZEROSLIKE,
  126. ONESLIKE, BESSELI0E, MOD,
  127. ASSIGN, ASSIGN_ADD, ATAN2,
  128. DIVNONAN, LOGICALAND, ELU,
  129. LOGICALOR, RELU6, SOFTPLUS,
  130. SOFTSIGN, LESS, LESSEQUAL,
  131. BESSELI1E, GREATEREQUAL, APPROXIMATEEQUAL,
  132. REPEAT_ELEMENTS};
  133. // clang-format on
  134. auto iter = elementwise_op.find(op_name);
  135. return (iter != elementwise_op.end());
  136. }
  137. bool IsSplittableOperator(const std::string &op_name) {
  138. // clang-format off
  139. static const std::set<std::string> splittable_op =
  140. {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
  141. FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
  142. REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
  143. MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
  144. LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
  145. STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
  146. SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
  147. EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH,
  148. EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
  149. BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
  150. SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
  151. UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE,
  152. UNSORTED_SEGMENT_MAX};
  153. // clang-format on
  154. auto iter = splittable_op.find(op_name);
  155. return (iter != splittable_op.end());
  156. }
  157. bool IsAutoParallelCareNode(const CNodePtr &cnode) {
  158. MS_EXCEPTION_IF_NULL(cnode);
  159. ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  160. if (prim_node == nullptr) {
  161. return false;
  162. }
  163. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
  164. if (prim == nullptr) {
  165. return false;
  166. }
  167. bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
  168. if (bool_result && (prim->name() != MAKE_TUPLE) && (prim->name() != MAKE_LIST)) {
  169. MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
  170. } else if (prim->name() == CAST) {
  171. if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
  172. // Do not care CASTs from optimizer
  173. return false;
  174. }
  175. return true;
  176. }
  177. return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
  178. }
// Recording the operators appearing in a for-loop.
// Currently, we assume that the operators in different for-loops are identical, and their traversal
// orderings are also identical.
// Therefore, we create OperatorInfo objects for the operators in a loop (say, loop-3), and reuse them in
// the rest of the loops (loop-2, loop-1 and loop-0).
// Keyed by OperatorInfo name; filled by the ConstructCostGraphNodesByUniqueId* functions below and
// queried by IsOperatorsInTwoSeparateLoops.
std::set<std::string> ops_in_a_loop_;
  185. // Whether two operators are in different loops; if it is true, then return true.
  186. // If at least one of the two operators is not in the loop, then return false.
  187. // If two operators are in the same loop, the return false.
  188. bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cnode) {
  189. auto a_op_info = a_cnode->user_data<OperatorInfo>();
  190. MS_EXCEPTION_IF_NULL(a_op_info);
  191. auto b_op_info = b_cnode->user_data<OperatorInfo>();
  192. MS_EXCEPTION_IF_NULL(b_op_info);
  193. if ((ops_in_a_loop_.find(a_op_info->name()) == ops_in_a_loop_.end()) ||
  194. (ops_in_a_loop_.find(b_op_info->name()) == ops_in_a_loop_.end())) {
  195. return false;
  196. }
  197. size_t a_loop_index = 0, b_loop_index = 0;
  198. const auto &a_fullname = a_cnode->fullname_with_scope();
  199. if (!GetLoopIndexFromCNode(a_cnode, &a_loop_index)) {
  200. MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << a_fullname << " was not included in the set.";
  201. }
  202. const auto &b_fullname = b_cnode->fullname_with_scope();
  203. if (!GetLoopIndexFromCNode(b_cnode, &b_loop_index)) {
  204. MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << b_fullname << " was not included in the set.";
  205. }
  206. if (a_loop_index == b_loop_index) {
  207. return false;
  208. }
  209. return true;
  210. }
  211. void InitCostGraph() {
  212. if (entire_costgraph == nullptr) {
  213. entire_costgraph = std::make_shared<CostGraph>();
  214. }
  215. entire_costgraph->SetDeviceMemoryAndCostParameter();
  216. entire_costgraph->Init();
  217. }
// Builds and fully configures an OperatorInfo for the primitive 'prim' carried by 'cnode':
// shapes, parameter flags, type lengths, constant input values, and finally a sharding strategy
// (either generated, user-configured, loaded from checkpoint, or batch-parallel for last nodes).
// Returns nullptr on recoverable failures so the caller can abort cost-graph construction;
// throws via MS_LOG(EXCEPTION) on unrecoverable ones.
// - is_last_nodes: whether 'cnode' is one of the last forward nodes in eval/predict mode.
// - stra_map: strategies loaded from a checkpoint, keyed by "<prim name>_<param name>".
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
                                      StrategyMap *stra_map) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(cnode);
  auto attrs = prim->attrs();
  std::vector<Shapes> shape_list = ExtractShape(cnode);
  if (shape_list.empty()) {
    MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
  }
  // Create an OperatorInfo instance
  OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
  MS_EXCEPTION_IF_NULL(operator_info);
  // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
  std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
  if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
    MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
    return nullptr;
  }
  // Set the data type for inputs and outputs of this OperatorInfo
  auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
  auto outputs_type = ExtractOutputTypeByNode(cnode);
  std::vector<size_t> outputs_type_length;
  outputs_type_length.reserve(outputs_type.size());
  // Map each output type to its byte length.
  std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
                 GetLengthOfDataType);
  if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
    MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
    return nullptr;
  }
  if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
    MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
    return nullptr;
  }
  // When the 'inputs' contains numerical values for some operators, these values should be extracted from
  // ANF graph
  auto &inputs = cnode->inputs();
  std::vector<ValuePtr> input_value;
  for (size_t index = 1; index < inputs.size(); ++index) {
    if (inputs[index]->isa<ValueNode>()) {
      input_value.push_back(GetValueNode(inputs[index]));
    } else {
      // Placeholder for non-constant inputs so positions keep lining up with the CNode's inputs.
      input_value.emplace_back(nullptr);
    }
  }
  operator_info->set_input_value(input_value);
  operator_info->set_outputs_dtype(cnode->Type());
  operator_info->set_cnode(cnode);
  // key of strategy map
  std::string strategy_key_name = "";
  auto param_names = NodeParameterName(cnode);
  if (!param_names.empty()) {
    strategy_key_name = prim->name() + "_" + param_names[0].first;
  }
  bool load_strategy_from_ckpt =
    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
  // If no strategy has been configured for this operator, then candidate strategies are generated for
  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
  // If strategy is set to load from checkpoint, it is preferred to load strategy from checkpoint.
  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt && !is_last_nodes) {
    // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
    // BatchParallelInfo operator
    operator_info->ComputeBatchSplitFlagList();
    if (operator_info->GenerateStrategies(0) != SUCCESS) {
      MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
      return nullptr;
    }
    // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
    auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
    if (approximation) {
      operator_info->ApproximateStrategies();
      MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
    }
  } else {
    // In this case, the configured strategy should be extracted to help setting cost
    StrategyPtr strategyPtr;
    if (is_last_nodes) {
      bool full_batch = ParallelContext::GetInstance()->full_batch();
      strategyPtr = GenerateBatchParallelStrategy(operator_info, prim);
      if (full_batch) {
        SetLastNodeStrategy(strategyPtr);
      }
    } else if (StrategyFound(attrs)) {
      strategyPtr = parallel::ExtractStrategy(attrs);
    } else {
      // Fall back to the checkpointed strategy for this operator.
      strategyPtr = (*stra_map)[strategy_key_name];
    }
    if (strategyPtr != nullptr) {
      if (prim->name() == RESHAPE) {
        MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
      }
      // Set cost for this configured strategy
      if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
      } else if (FULLY_USE_DEVICES) {
        // If configured to fully use devices, then checking for the user-specified strategy
        int64_t used_devices = operator_info->used_devices();
        MS_EXCEPTION_IF_NULL(g_device_manager);
        auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
        // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel
        if (used_devices == 1) {
          return operator_info;
        }
        // 'used_devices == -1' means that 'used_devices_' is not set
        if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) {
          MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, "
                            << "but the specified strategy uses device: " << used_devices
                            << ", total devices: " << total_device_num;
        }
      }
    }
  }
  return operator_info;
}
// Using CNode's UniqueIds to construct nodes
// Walks 'all_nodes' and, for every auto-parallel-care primitive CNode, creates (or — for operators
// inside a for-loop — reuses) an OperatorInfo and registers it into 'entire_costgraph'.
// Returns FAILED when an OperatorInfo cannot be created; SUCCESS otherwise.
// Throws when two CNodes share the same UniqueId or a reused loop OperatorInfo does not match
// the current primitive.
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  // The map from CNode's UniqueId to its operatorInfo
  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  // The operator_infos in a loop
  std::vector<OperatorInfoPtr> operators_in_forloop;
  // Key: i-th loop; Value: index of 'operators_in_forloop'
  std::map<size_t, size_t> loop_to_ops;
  // extract strategy from checkpoint for multi-train
  StrategyMap stra_map;
  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
    }
  }
  // In eval/predict mode, the last forward nodes get a batch-parallel strategy (see
  // CreateTheOperatorInfo), so collect their UniqueIds up front.
  std::vector<std::string> last_forward_node_ids;
  if (!root->has_flag(TRAINING)) {
    FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
    MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
  }
  for (auto &node : all_nodes) {
    // NOTE: we only care about splittable Primitive operators
    auto cnode = node->cast<CNodePtr>();
    bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
    if (bool_result) {
      continue;
    }
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    if (!IsAutoParallelCareNode(cnode)) {
      // Needed by rec_parser
      if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
        auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
        if (prev_cnode != nullptr) {
          entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
        }
      }
      continue;
    }
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(prim);
    auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
    if (search_cnode == from_cnode_to_info.end()) {
      size_t loop_index = 0;
      bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
      // Reuse the OperatorInfo created for the same position in an earlier, identical for-loop.
      if (DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
        const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
        // The reused OperatorInfo must correspond to the same primitive (virtual-dataset and
        // batch-parallel infos are exempt from the name check).
        bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
                             (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
                             (current_op_ptr->name().find(prim->name()) == std::string::npos);
        if (is_find_wrong) {
          MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
                            << " does not match the Prim: " << prim->name()
                            << ". The fullname_with_scope: " << cnode->fullname_with_scope();
        }
        loop_to_ops[loop_index]++;
        cnode->set_user_data<OperatorInfo>(current_op_ptr);
        MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                     << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                     << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
                     << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
        (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
        continue;
      }
      bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
                           last_forward_node_ids.end();
      auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
      if (operator_info == nullptr) {
        return FAILED;
      }
      // Needed by rec_parser
      operator_info->set_type(prim->name());
      operator_info->set_last_node_flag(is_last_nodes);
      std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
      entire_costgraph->AddOperator(operator_info);
      cnode->set_user_data<OperatorInfo>(operator_info);
      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                   << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
                   << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info));
      // Record this operator so later iterations of the same loop can reuse it.
      if (DP_ALGO_SINGLE_LOOP && is_in_loop) {
        operators_in_forloop.push_back(operator_info);
        ops_in_a_loop_.insert(operator_info->name());
        loop_to_ops[loop_index]++;
      }
      // Needed by rec_parser
      entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
    } else {
      // Two CNODEs' UniqueIds should not be equal
      MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
                        << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                        << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
    }
  }
  MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  return SUCCESS;
}
  429. // Using CNode's UniqueIdThroughCopys to construct nodes
  430. Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  431. MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  432. // The map from CNode's UniqueIdThroughCopy to its operatorInfo
  433. std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  434. // The operator_infos in a loop
  435. std::vector<OperatorInfoPtr> operators_in_forloop;
  436. // Key: i-th loop; Value: index of 'operators_in_forloop'
  437. std::map<size_t, size_t> loop_to_ops;
  438. // extract strategy from checkpoint for multi-train
  439. StrategyMap stra_map;
  440. if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
  441. if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
  442. MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  443. }
  444. }
  445. std::vector<std::string> last_forward_node_ids;
  446. if (!root->has_flag(TRAINING)) {
  447. FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
  448. MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
  449. }
  450. for (auto &node : all_nodes) {
  451. // NOTE: we only care about splittable Primitive operators
  452. auto cnode = node->cast<CNodePtr>();
  453. bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
  454. if (bool_result) {
  455. continue;
  456. }
  457. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  458. if (!IsAutoParallelCareNode(cnode)) {
  459. // Needed by rec_parser
  460. if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
  461. auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
  462. if (prev_cnode != nullptr) {
  463. entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
  464. }
  465. }
  466. continue;
  467. }
  468. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  469. // Find the operatorInfo if it exists
  470. auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
  471. if (search_cnode == from_cnode_to_info.end()) {
  472. size_t loop_index = 0;
  473. bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
  474. if (DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
  475. const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
  476. bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
  477. (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
  478. (current_op_ptr->name().find(prim->name()) == std::string::npos);
  479. if (is_find_wrong) {
  480. MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
  481. << " does not match the Prim: " << prim->name()
  482. << ". The fullname_with_scope: " << cnode->fullname_with_scope();
  483. }
  484. loop_to_ops[loop_index]++;
  485. cnode->set_user_data<OperatorInfo>(current_op_ptr);
  486. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  487. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  488. << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
  489. << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
  490. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), current_op_ptr));
  491. continue;
  492. }
  493. // In this case, the corresponding OperatorInfo is not created, create the new one.
  494. bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
  495. last_forward_node_ids.end();
  496. auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
  497. if (operator_info == nullptr) {
  498. return FAILED;
  499. }
  500. // Needed by rec_parser
  501. operator_info->set_type(prim->name());
  502. operator_info->set_last_node_flag(is_last_nodes);
  503. std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
  504. entire_costgraph->AddOperator(operator_info);
  505. cnode->set_user_data<OperatorInfo>(operator_info);
  506. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  507. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  508. << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
  509. << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
  510. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
  511. if (DP_ALGO_SINGLE_LOOP && is_in_loop) {
  512. operators_in_forloop.push_back(operator_info);
  513. ops_in_a_loop_.insert(operator_info->name());
  514. loop_to_ops[loop_index]++;
  515. }
  516. // Needed by rec_parser
  517. entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
  518. } else {
  519. auto current_op_ptr = search_cnode->second;
  520. if (current_op_ptr == nullptr) {
  521. MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
  522. } else {
  523. bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
  524. (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
  525. (current_op_ptr->name().find(prim->name()) == std::string::npos);
  526. if (is_find_wrong) {
  527. MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
  528. << " does not match the Prim: " << prim->name();
  529. }
  530. // Needed by rec_parser
  531. ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());
  532. cnode->set_user_data<OperatorInfo>(current_op_ptr);
  533. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  534. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  535. << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
  536. << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
  537. }
  538. }
  539. }
  540. MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  541. return SUCCESS;
  542. }
// Step 2 of strategy search: create the EDGES of the cost graph.
// For every auto-parallel-care CNode, walks each of its inputs back to the
// producing operator (transparently skipping 'tuple_getitem' and 'depend'
// nodes) and adds an Edge between the two OperatorInfos. Edge costs model the
// tensor redistribution between the producer's output layout and the
// consumer's input layout.
void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
  // Step 2
  MS_LOG(INFO) << "Constructing edges for cost graph begins.";
  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    // Only CNodes whose first input is a Primitive are operator applications.
    bool bool_result_cnode = (cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0));
    if (bool_result_cnode) {
      continue;
    }
    auto &inputs = cnode->inputs();
    ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
    if (!IsAutoParallelCareNode(cnode)) {
      continue;
    }
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    size_t edge_count = 0;
    // OperatorInfo for this consumer node was attached in Step 1.
    auto node_op_info = cnode->user_data<OperatorInfo>();
    // inputs[0] is the primitive itself; real operands start at index 1.
    for (size_t i = 1; i < inputs.size(); ++i) {
      auto prev_cnode = inputs[i]->cast<CNodePtr>();
      bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
      if (bool_result_prev_cnode) {
        continue;
      }
      ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
      PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
      // Which output of the producer feeds this input (updated when jumping
      // over 'tuple_getitem').
      size_t output_index = 0;
      // Keep walking while the predecessor is either a real operator
      // (edge is created, then break) or a pass-through node to jump over.
      bool bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
                         (prev_prim->name() == DEPEND);
      while (bool_result) {
        if (IsAutoParallelCareNode(prev_cnode)) {
          auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
          std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
          // If the edge between these two operators already has been added, then the edge will not be added again.
          if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
            break;
          }
          EdgePtr edge_ptr;
          MS_LOG(INFO) << "Creating edge: " << edge_name;
          // Operators living in two different for-loop bodies must not be
          // connected (their iteration spaces are independent).
          if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) {
            MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope()
                         << ", cnode_fullname: " << cnode->fullname_with_scope();
            MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge.";
            break;
          }
          // Reshape edges, and (optionally) element-wise successors, must use
          // the same strategy as the producer: no redistribution allowed.
          bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
                                 (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
          if (follow_strategy) {
            // Redistribution in not allowed on the edge.
            // Elementwise operators have the same strategy as their previous operators.
            edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true);
          } else {
            edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false);
          }
          // Init costs for this edge
          if (edge_ptr->InitEdgeCost() != SUCCESS) {
            MS_LOG(EXCEPTION) << "Edge cost initialization failed";
          }
          node_op_info->AddPrevEdge(edge_ptr);
          prev_op_info->AddSuccEdge(edge_ptr);
          entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
          MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and "
                       << node_op_info->name();
          edge_count++;
          break;
        } else if (prev_prim->name() == prim::kTupleGetItem) {
          // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
          // this 'tuple_getitem'
          MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
          // input(2) of tuple_getitem holds the tuple index -> producer output index.
          output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2))));
          prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
          bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
          if (bool_result_tuple) {
            break;
          }
          prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
          prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
          // A tuple_getitem must be fed by a real operator; anything else is a bug.
          if (!IsAutoParallelCareNode(prev_cnode)) {
            MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
          }
          MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
                       << "and creating an edge between the Operator before "
                       << "'tuple_getitem' and the Operator after 'tuple_getitem'.";
        } else if (prev_prim->name() == DEPEND) {
          // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before
          // this 'depend'
          MS_LOG(INFO) << "Jumping the 'depend' operator.";
          prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
          bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
          if (bool_result_depend) {
            break;
          }
          prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
          prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
          MS_LOG(INFO) << "Jumped the 'depend' operator, "
                       << "and creating an edge between the Operator before "
                       << "'depend' and the Operator after 'depend'.";
        }
        // Re-evaluate with the (possibly) new predecessor.
        bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
                      (prev_prim->name() == DEPEND);
      }
    }
    MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
  }
  // If 'approximation' is enabled, the edges need to be checked have effective costs.
  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
  if (approximation) {
    entire_costgraph->CheckApproximateCostGraphEdges();
  }
  MS_LOG(INFO) << "Constructing edges for cost graph ends.";
}
// Step 3 of strategy search: augment the cost graph for shared Parameters.
// When one Parameter (RefKey) is consumed by several distinct operators, a
// TmpIdentityInfo operator is inserted to represent the parameter, with one
// edge from it to each consumer, so the DP algorithm can account for the
// cost of distributing that parameter.
void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  // Step 3
  for (auto &node : all_nodes) {
    // parameter_users_info: (parameter name, (parameter node, set of (user cnode, input index))).
    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsAutoParallelCareNode);
    auto parameter_name = parameter_users_info.first;
    auto target_parameter = parameter_users_info.second.first;
    auto target_set = parameter_users_info.second.second;
    if (target_set.size() <= 1) {
      continue;
    }
    // Rule out the case when a Parameter being used by a Operator, but the Operator appears in multiple CNODEs
    std::set<std::string> target_without_duplicate;
    for (auto &target : target_set) {
      auto target_cnode = target.first->cast<CNodePtr>();
      auto input_index = target.second;
      // Key combines input index and operator name so the same operator used
      // at the same slot counts once.
      (void)target_without_duplicate.insert(std::to_string(input_index) +
                                            target_cnode->user_data<OperatorInfo>()->name());
    }
    if (target_without_duplicate.size() <= 1) {
      continue;
    }
    // Here, it is sure that this Parameter (RefKey) is being used by multiple Operators.
    OperatorInfoPtr tmp_identity_ptr;
    bool new_identity = false;
    std::string tmp_identity_name;
    auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
    if (returned_identity != nullptr) {
      // In this case, the TmpIdentityInfo instance has already been created
      new_identity = false;
      tmp_identity_ptr = returned_identity;
      tmp_identity_name = tmp_identity_ptr->name();
    } else {
      // In the case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created.
      new_identity = true;
      // 1) extract input shape from this Parameter
      MS_EXCEPTION_IF_NULL(target_parameter);
      AbstractBasePtr abstract = target_parameter->abstract();
      if (abstract == nullptr) {
        MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
      }
      auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
      if (input_shape == nullptr) {
        MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
      }
      Shape shape = input_shape->shape();
      // Identity: output shape equals input shape.
      Shapes inputs_shape = {shape};
      Shapes outputs_shape = {shape};
      // 2) init the attr
      std::unordered_map<std::string, ValuePtr> attr = {};
      // Create the TmpIdentity instance
      tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
      // Suffix with the global op counter so each TmpIdentity name is unique.
      tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
      TOTAL_OPS++;
      tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
      // Set the parameter and type lengths for inputs and outputs
      std::vector<bool> is_parameter;
      auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(casted_target_parameter);
      is_parameter.push_back(ParameterRequireGrad(casted_target_parameter));
      if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
      }
      auto node_type = target_parameter->Type();
      if (node_type->isa<mindspore::TensorType>()) {
        auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
        std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
        if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
          MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
        }
      } else {
        MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
      }
      // Generate strategies for this TmpIdentityInfo instance;
      if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name();
      }
    }
    // A flag recording whether new edges have been created or not
    bool add_identity_edge = false;
    // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
    for (auto &target : target_set) {
      auto target_cnode = target.first->cast<CNodePtr>();
      auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
      auto input_index = target.second;
      auto target_op_info = target_cnode->user_data<OperatorInfo>();
      std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
      // If the edge between these two operators already has been added, then the edge will not be added again.
      if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, LongToSize(input_index - 1))) {
        continue;
      }
      // Identity output 0 feeds consumer operand (input_index - 1); strategy
      // must follow (last arg true): no redistribution on this edge.
      std::shared_ptr<Edge> edge_ptr =
          std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
      // If 'approximation' is enabled, the edges need to be checked have effective costs.
      auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
      if (approximation) {
        target_op_info->ExactStrategiesAndRelatedEdges();
      }
      if (edge_ptr->InitEdgeCost() != SUCCESS) {
        MS_LOG(EXCEPTION) << "Edge cost initialization failed";
      }
      target_op_info->AddPrevEdge(edge_ptr);
      tmp_identity_ptr->AddSuccEdge(edge_ptr);
      entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr);
      MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
                   << target_op_info->name();
      add_identity_edge = true;
    }
    if (new_identity && add_identity_edge) {
      // Add the TmpIdentityInfo to CostGraph if BOTH two conditions are satisfied
      entire_costgraph->AddOperator(tmp_identity_ptr);
    }
  }
}
  766. void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
  767. std::unordered_set<std::string> op_cache;
  768. for (auto node : all_nodes) {
  769. auto cnode = node->cast<CNodePtr>();
  770. if (!FindReshape(cnode, &op_cache)) {
  771. continue;
  772. }
  773. MS_ASSERT(cnode->inputs().size() == 3);
  774. // get previous node's strategy_cost_
  775. auto pre_node = cnode->input(1);
  776. if (IsPrimitiveCNode(pre_node, prim::kPrimLoad)) {
  777. pre_node = pre_node->cast<CNodePtr>()->input(1);
  778. }
  779. int64_t out_index = 0;
  780. OperatorInfoPtr pre_operator_info;
  781. std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
  782. auto operator_info = cnode->user_data<OperatorInfo>();
  783. if (pre_node->isa<Parameter>()) {
  784. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
  785. reshape_info->SetCostForReshapeWithParameter();
  786. pre_operator_info = reshape_info;
  787. pre_stra_costs = reshape_info->strategy_cost();
  788. } else {
  789. if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
  790. MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed";
  791. }
  792. pre_stra_costs = pre_operator_info->strategy_cost();
  793. }
  794. // get next node's strategy_cost_
  795. int64_t in_index = 0;
  796. OperatorInfoPtr next_operator_info;
  797. std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
  798. bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index);
  799. if (!find_next_node) {
  800. MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
  801. }
  802. // set input_layout and output_layout for reshape.
  803. // init reshape and set cost for each input_layout and output_layout.
  804. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
  805. reshape_info->set_pre_operator_name(pre_operator_info->name());
  806. reshape_info->set_pre_operator_index(out_index);
  807. if (find_next_node) {
  808. next_stra_costs = next_operator_info->strategy_cost();
  809. reshape_info->set_next_operator_name(next_operator_info->name());
  810. reshape_info->set_next_operator_index(in_index);
  811. }
  812. bool is_prev_param = pre_node->isa<Parameter>();
  813. if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
  814. SUCCESS) {
  815. MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!";
  816. }
  817. }
  818. }
  819. Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  820. // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
  821. // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
  822. // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
  823. // for each OperatorInfo;
  824. // Step 1.1: Deal with 'Reshape':
  825. // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
  826. // layout as its output layout.
  827. // Step 2: Traverse the ANF graph, and create EDGES for costgraph:
  828. // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
  829. // for each edge, based on the strategies of two OperatorInfos;
  830. // Step 3: Augment the costgraph:
  831. // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
  832. // operator for this Parameter, and add an edge for the use of this Parameter by each
  833. // subsequent operator;
  834. // Step 3.1: Calculate memory usage:
  835. // note the memory usage calculation is different in training phase and inference phase.
  836. // Step 4: Run the Dynamic Programming algorithm:
  837. // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
  838. // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
  839. // tensor layout. Note that there may be several connected components in the costgraph, and the DP algorithm
  840. // runs on each of them.
  841. //
  842. // OUTPUT: the determined strategy for each operator.
  843. InitCostGraph();
  844. // Step 1
  845. if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
  846. if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
  847. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  848. << entire_costgraph->GetOperators().size() << " operators.";
  849. } else {
  850. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  851. }
  852. } else {
  853. if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
  854. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  855. << entire_costgraph->GetOperators().size() << " operators.";
  856. } else {
  857. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  858. }
  859. }
  860. // Step 1.1
  861. ReshapeCostCompute(all_nodes);
  862. // Step 2
  863. ConstructCostGraphEdges(all_nodes);
  864. MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
  865. << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
  866. // Step 3: Augment the costgraph.
  867. AugmentCostGraph(all_nodes);
  868. auto num_ops = entire_costgraph->GetOperators().size();
  869. SetOpsNumToExecutor(num_ops);
  870. auto num_edges = entire_costgraph->GetNumEdges();
  871. MS_LOG(INFO) << "After the augmenting procedure, there are " << num_ops << " operators, and " << num_edges
  872. << " edges.";
  873. // Step 3.1: Calculate the memory usage
  874. if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
  875. MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
  876. }
  877. // Step 4: run DP algorithm on the costgraph.
  878. if (GetStrategy(entire_costgraph) != SUCCESS) {
  879. MS_LOG(ERROR) << "Strategy search for cost-graph fails";
  880. return FAILED;
  881. }
  882. MS_LOG(INFO) << "Searching strategy succeeded.";
  883. if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
  884. MS_LOG(INFO) << "Init selected strategy succeeded.";
  885. } else {
  886. MS_LOG(EXCEPTION) << "Init selected strategy failed.";
  887. }
  888. // print the selected strategy
  889. for (auto &op : entire_costgraph->GetOperators()) {
  890. StrategyPtr s_strategy = op->selected_strategy();
  891. MS_LOG(INFO) << op->name() << " : The strategy is:";
  892. PrintStrategy(s_strategy);
  893. }
  894. ops_in_a_loop_.clear();
  895. return SUCCESS;
  896. }
  897. std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
  898. std::vector<std::vector<std::string>> input_tensor_names) {
  899. for (size_t j = 0; j < input_tensor_names.size(); j++) {
  900. for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
  901. if (it->first == input_tensor_names[j][k]) {
  902. input_tensor_names[j][k] = it->second;
  903. break;
  904. }
  905. }
  906. }
  907. return input_tensor_names;
  908. }
  909. CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
  910. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  911. if (prim->name() == prim::kTupleGetItem || prim->name() == DEPEND) {
  912. auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
  913. if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
  914. return nullptr;
  915. }
  916. auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  917. while (prev_prim->name() == prim::kTupleGetItem || prev_prim->name() == DEPEND) {
  918. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  919. if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
  920. return nullptr;
  921. }
  922. prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  923. }
  924. return prev_cnode;
  925. }
  926. return nullptr;
  927. }
  928. void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
  929. size_t iter_ops = 0;
  930. for (auto op : entire_costgraph->GetOperators()) {
  931. if (op->name() == name) {
  932. break;
  933. }
  934. iter_ops = iter_ops + 1;
  935. }
  936. std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
  937. for (size_t i = 0; i < input_tensor_names.size(); i++) {
  938. for (size_t j = 0; j < input_tensor_names[i].size(); j++) {
  939. if (input_tensor_names[i][j] == uniqueid) {
  940. input_tensor_names[i][j] = input_tensor_names[iter_ops][0];
  941. }
  942. }
  943. }
  944. entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
  945. }
// Strategy search using the recursive-programming ("rec") algorithm: builds
// the cost-graph nodes, converts the operator list plus the recorded
// input-tensor-name table into a rec Graph, partitions it over all devices,
// then generates and initializes the selected strategies.
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  InitCostGraph();
  // Node construction differs for multi-subgraph graphs (same as the DP path).
  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  } else {
    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  }
  ReshapeCostCompute(all_nodes);
  auto ops = entire_costgraph->GetOperators();
  std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
  // Rewrite tensor names recorded for 'tuple_getitem' nodes; the post-increment
  // passes the current iterator while advancing the loop.
  auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
  for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
    input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
  }
  std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
  // eli_list/index_list are outputs of the elimination pass, consumed later by
  // GenerateStrategy.
  std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
  std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
  graph = EliminateGraph(graph, eli_list, index_list);
  size_t num_device = g_device_manager->DeviceNum();
  double device_memory = entire_costgraph->GetDeviceMemory();
  if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
    MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
  } else {
    MS_LOG(ERROR) << "PartitionForAllDevices failed.";
    return FAILED;
  }
  // Inference graphs (no TRAINING flag) are costed differently.
  bool is_training = true;
  if (!root->has_flag(TRAINING)) {
    is_training = false;
  }
  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training);
  if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
    MS_LOG(INFO) << "Init selected strategy succeeded.";
  } else {
    MS_LOG(ERROR) << "Init selected strategy failed.";
    return FAILED;
  }
  // print the selected strategy
  for (auto &op : entire_costgraph->GetOperators()) {
    StrategyPtr s_strategy = op->selected_strategy();
    MS_LOG(INFO) << op->name() << " : The strategy is:";
    PrintStrategy(s_strategy);
  }
  return SUCCESS;
}
  1001. } // namespace parallel
  1002. } // namespace mindspore