step_auto_parallel.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_auto_parallel.h"

#include <inttypes.h>
#include <sys/time.h>

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "base/core_ops.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
#include "frontend/parallel/auto_parallel/edge_costmodel.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/ops_info/reshape_info.h"
#include "frontend/parallel/ops_info/tmp_identity_info.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "ir/anf.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#if (ENABLE_CPU && !_WIN32)
#include "ps/util.h"
#endif

namespace mindspore {
namespace parallel {

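// Entry point of the auto-parallel step. The pass is skipped on parameter-server and
// scheduler roles, when the graph is not in AUTO_PARALLEL mode, or when it has already run
// once (AUTO_PARALLEL_RUN_ONCE_ONLY). Otherwise it builds a cost graph from the ANF graph
// and searches a parallelization strategy with either the dynamic-programming (DP) or the
// recursive-programming (RP) algorithm.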
bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
#if (ENABLE_CPU && !_WIN32)
  if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
    return false;
  }
#endif
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  // assume no change to the graph
  bool changes = false;
  // control whether to use the model_parallel mode
  if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
      root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
    return changes;
  }

  // check whether strategy_search_mode is valid
  std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
  if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
    // Set the searching mode to dynamic programming as the default.
    strategy_search_mode = DYNAMIC_PROGRAMMING;
    MS_LOG(INFO) << "No strategy searching mode indicated, using DP searching mode as the default";
  }

  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);

  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
  }
  MS_LOG(INFO) << "Now entering step auto parallel";
  TOTAL_OPS = 0;
  AnfNodePtr ret = root->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);

  if (ParallelInit() != SUCCESS) {
    MS_LOG(EXCEPTION) << "Parallel init failed";
  }

  // mark the forward cnodes; parallelization only cares about these nodes
  MarkForwardCNode(root);

  if (!root->has_flag(TRAINING)) {
    InsertVirtualOutput(root, all_nodes);
    AnfNodePtr ret_after = root->get_return();
    MS_EXCEPTION_IF_NULL(ret_after);
    all_nodes = DeepScopedGraphSearch(ret_after);
  }

  if (FindCommunicationOp(all_nodes)) {
    MS_LOG(EXCEPTION) << "The graph contains a communication op";
  }

  // search the parallelization strategy
  if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
    if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
    }
  } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
    if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
    }
  } else {
    MS_LOG(EXCEPTION) << "Unexpected auto-parallel strategy searching mode";
  }

  (void)gettimeofday(&end_time, nullptr);
  uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";

  root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
  return changes;
}

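// Elementwise operators can simply follow the sharding strategy of their input, so they are
// treated specially when constructing cost-graph edges: when 'elementwise_stra_follow' is
// enabled, redistribution on their incoming edge is forbidden (see CreateEdgeBetweenTwoOps).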
bool IsElementWiseOperator(const std::string &op_name) {
  // clang-format off
  static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH,
                                                       SOFTMAX, LOG_SOFTMAX, RELU,
                                                       SQRT, CAST, POW,
                                                       EXP, LOG, COS,
                                                       ACOS, LOGICALNOT, NEG,
                                                       SQUARE, SIGMOID, ABS,
                                                       ACOSH, ASIN, ASINH,
                                                       ATAN, ATANH, CEIL,
                                                       COSH, EXPM1, LOG1P,
                                                       SIN, SINH, TAN,
                                                       RSQRT, RECIPROCAL, INV,
                                                       ROUND, FLOOR, SIGN,
                                                       ERF, ERFC, ZEROSLIKE,
                                                       ONESLIKE, BESSELI0E, MOD,
                                                       ASSIGN, ASSIGN_ADD, ATAN2,
                                                       DIVNONAN, LOGICALAND, ELU,
                                                       LOGICALOR, RELU6, SOFTPLUS,
                                                       SOFTSIGN, LESS, LESSEQUAL,
                                                       BESSELI1E, GREATEREQUAL, APPROXIMATEEQUAL,
                                                       REPEAT_ELEMENTS};
  // clang-format on
  auto iter = elementwise_op.find(op_name);
  return (iter != elementwise_op.end());
}

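// An operator is "splittable" if an OperatorInfo class implementing its parallelization
// (shape propagation, strategy enumeration and cost computation) exists. Only these
// operators become nodes of the cost graph.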
bool IsSplittableOperator(const std::string &op_name) {
  // clang-format off
  static const std::set<std::string> splittable_op =
    {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
     FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
     REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
     MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
     LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
     STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
     EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH,
     EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
     UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT,
     UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT};
  // clang-format on
  auto iter = splittable_op.find(op_name);
  return (iter != splittable_op.end());
}

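// A CNode is cared about by auto-parallel if it is a parallel-care node whose primitive is
// splittable. A parallel-care primitive without an OperatorInfo implementation raises an
// exception (except MakeTuple/MakeList); CASTs inserted by the optimizer are ignored.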
bool IsAutoParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
  if (prim == nullptr) {
    return false;
  }
  bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
  if (bool_result && (prim->name() != MAKE_TUPLE) && (prim->name() != MAKE_LIST)) {
    MS_LOG(EXCEPTION) << "Should implement OperatorInfo for: " << prim->name();
  } else if (prim->name() == CAST) {
    if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
      // Do not care about CASTs from the optimizer
      return false;
    }
    return true;
  }
  return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
}

// Recording the operators appearing in a for-loop.
// Currently, we assume that the operators in different for-loops are identical, and that their
// traversal orderings are also identical.
// Therefore, we create OperatorInfo objects for the operators in one loop (say, loop-3), and
// reuse them in the remaining loops (loop-2, loop-1 and loop-0).
std::set<std::string> ops_in_a_loop_;

// Whether two operators are in two different loops: if so, return true.
// If at least one of the two operators is not in a loop, return false.
// If the two operators are in the same loop, return false.
bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cnode) {
  auto a_op_info = a_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(a_op_info);
  auto b_op_info = b_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(b_op_info);
  if ((ops_in_a_loop_.find(a_op_info->name()) == ops_in_a_loop_.end()) ||
      (ops_in_a_loop_.find(b_op_info->name()) == ops_in_a_loop_.end())) {
    return false;
  }
  size_t a_loop_index = 0, b_loop_index = 0;
  const auto &a_fullname = a_cnode->fullname_with_scope();
  if (!GetLoopIndexFromCNode(a_cnode, &a_loop_index)) {
    MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << a_fullname << " was not included in the set.";
  }
  const auto &b_fullname = b_cnode->fullname_with_scope();
  if (!GetLoopIndexFromCNode(b_cnode, &b_loop_index)) {
    MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << b_fullname << " was not included in the set.";
  }
  if (a_loop_index == b_loop_index) {
    return false;
  }
  return true;
}

void InitCostGraph() {
  if (entire_costgraph == nullptr) {
    entire_costgraph = std::make_shared<CostGraph>();
  }
  MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
  CostModelContext::GetInstance()->PrintCostModel();
  entire_costgraph->Init();
}

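// Applies a user-configured strategy to an operator: the strategy comes either from the
// primitive's attrs or, when loading from a checkpoint, from 'stra_map'. The cost under that
// strategy is then set; if FULLY_USE_DEVICES is configured, the strategy must use either
// exactly one device or all devices of stage 0.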
void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim,
                           const std::unordered_map<std::string, ValuePtr> &attrs, bool is_last_nodes,
                           StrategyMap *stra_map, const std::string &strategy_key_name) {
  // In this case, the configured strategy should be extracted to help setting the cost
  StrategyPtr strategyPtr;
  if (StrategyFound(attrs)) {
    strategyPtr = parallel::ExtractStrategy(attrs);
  } else {
    strategyPtr = (*stra_map)[strategy_key_name];
  }
  if (strategyPtr != nullptr) {
    if (prim->name() == RESHAPE) {
      MS_LOG(EXCEPTION) << "Setting a strategy for Reshape has no effect!";
    }
    const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
    // Set the cost for this configured strategy
    if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
      MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
    } else if (fully_use_devices) {
      // If configured to fully use devices, then check the user-specified strategy
      int64_t used_devices = operator_info->used_devices();
      MS_EXCEPTION_IF_NULL(g_device_manager);
      auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
      // 'used_devices == 1' means the ALL-1 strategy, which is valid in auto-parallel
      if (used_devices == 1) {
        return;
      }
      // 'used_devices == -1' means that 'used_devices_' is not set
      if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) {
        MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, "
                          << "but the specified strategy uses devices: " << used_devices
                          << ", total devices: " << total_device_num;
      }
    }
  }
}

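// Creates and initializes the OperatorInfo for one CNode: extracts the input/output shapes,
// parameter flags, type lengths and constant input values, then either generates candidate
// strategies for searching (optionally approximated) or applies a configured/checkpointed
// strategy via SetStrategyToOperator. Returns nullptr on failure.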
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
                                      StrategyMap *stra_map) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(cnode);
  auto attrs = prim->attrs();
  std::vector<Shapes> shape_list = ExtractShape(cnode);
  if (shape_list.empty()) {
    MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
  }
  // Create an OperatorInfo instance
  OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
  MS_EXCEPTION_IF_NULL(operator_info);
  // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
  std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
  if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
    MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
    return nullptr;
  }
  // Set the data types for the inputs and outputs of this OperatorInfo
  auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
  auto outputs_type = ExtractOutputTypeByNode(cnode);
  std::vector<size_t> outputs_type_length;
  outputs_type_length.reserve(outputs_type.size());
  std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
                 GetLengthOfDataType);
  if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
    MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
    return nullptr;
  }
  if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
    MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
    return nullptr;
  }
  // When the 'inputs' contain numerical values for some operators, these values should be extracted
  // from the ANF graph
  auto &inputs = cnode->inputs();
  std::vector<ValuePtr> input_value;
  for (size_t index = 1; index < inputs.size(); ++index) {
    if (inputs[index]->isa<ValueNode>()) {
      input_value.push_back(GetValueNode(inputs[index]));
    } else {
      input_value.emplace_back(nullptr);
    }
  }
  operator_info->set_input_value(input_value);
  operator_info->set_outputs_dtype(cnode->Type());
  operator_info->set_cnode(cnode);
  // key of the strategy map
  std::string strategy_key_name = "";
  auto param_names = NodeParameterName(cnode, -1, 0);
  if (!param_names.empty()) {
    strategy_key_name = prim->name() + "_" + param_names[0].first;
  }
  bool load_strategy_from_ckpt =
    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
  // If no strategy has been configured for this operator, then candidate strategies are generated for
  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
  // If the strategy is set to be loaded from a checkpoint, loading it from the checkpoint is preferred.
  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
    // Compute split_flag_list_, indicating which input has the batch dimension. This is ONLY used in
    // preparation for the BatchParallelInfo operator
    operator_info->ComputeBatchSplitFlagList();
    if (operator_info->GenerateStrategies(0) != SUCCESS) {
      MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
      return nullptr;
    }
    // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
    auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
    if (approximation) {
      operator_info->ApproximateStrategies();
      MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
    }
  } else {
    SetStrategyToOperator(operator_info, prim, attrs, is_last_nodes, stra_map, strategy_key_name);
  }
  return operator_info;
}

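// Checks whether a cached OperatorInfo does NOT match the given primitive name: the match
// holds if the OperatorInfo is a VirtualDataset info, a BatchParallel info, or its name
// contains "<prim_name>Info" (for GatherV2 also "<prim_name>PInfo").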
bool IsFindWrong(const OperatorInfoPtr current_op_ptr, const std::string &prim_name) {
  bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
                       (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
                       (current_op_ptr->name().find(prim_name + "Info") == std::string::npos);
  if (prim_name == GATHERV2) {
    is_find_wrong = is_find_wrong && (current_op_ptr->name().find(prim_name + "PInfo") == std::string::npos);
  }
  return is_find_wrong;
}

// Using CNodes' UniqueIds to construct the nodes
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  // The map from a CNode's UniqueId to its OperatorInfo
  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  // The operator_infos in a loop
  std::vector<OperatorInfoPtr> operators_in_forloop;
  // Key: i-th loop; Value: index into 'operators_in_forloop'
  std::map<size_t, size_t> loop_to_ops;
  // extract the strategy from the checkpoint for multi-train
  StrategyMap stra_map;
  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
    }
  }
  for (auto &node : all_nodes) {
    // NOTE: we only care about splittable Primitive operators
    auto cnode = node->cast<CNodePtr>();
    bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
    if (bool_result) {
      continue;
    }
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    if (!IsAutoParallelCareNode(cnode)) {
      // Needed by the rec_parser
      if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
        auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
        if (prev_cnode != nullptr) {
          entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
        }
      }
      continue;
    }
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(prim);

    auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
    if (search_cnode == from_cnode_to_info.end()) {
      size_t loop_index = 0;
      bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
      const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
      if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
        const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
        if (IsFindWrong(current_op_ptr, prim->name())) {
          MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
                            << " does not match the Prim: " << prim->name()
                            << ". The fullname_with_scope: " << cnode->fullname_with_scope();
        }
        loop_to_ops[loop_index]++;
        cnode->set_user_data<OperatorInfo>(current_op_ptr);
        MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                     << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                     << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
                     << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
        (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
        continue;
      }
      bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
      auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
      if (operator_info == nullptr) {
        return FAILED;
      }
      // Needed by the rec_parser
      operator_info->set_type(prim->name());
      operator_info->set_last_node_flag(is_last_nodes);
      std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);

      entire_costgraph->AddOperator(operator_info);
      cnode->set_user_data<OperatorInfo>(operator_info);
      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                   << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
                   << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info));
      if (single_loop && is_in_loop) {
        operators_in_forloop.push_back(operator_info);
        ops_in_a_loop_.insert(operator_info->name());
        loop_to_ops[loop_index]++;
      }
      // Needed by the rec_parser
      entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
    } else {
      // Two CNodes' UniqueIds should not be equal
      MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
                        << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                        << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
    }
  }
  MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  return SUCCESS;
}

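// Attaches an already-created OperatorInfo (found in the cost graph) to a CNode, after
// verifying that it matches the CNode's primitive.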
void SetOperatorToCNode(const OperatorInfoPtr &current_op_ptr, const PrimitivePtr &prim, const CNodePtr &cnode) {
  if (current_op_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
  } else {
    if (IsFindWrong(current_op_ptr, prim->name())) {
      MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
                        << " does not match the Prim: " << prim->name();
    }
    // Needed by the rec_parser
    ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());
    cnode->set_user_data<OperatorInfo>(current_op_ptr);
    MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
                 << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
  }
}

// Using CNodes' UniqueIdThroughCopys to construct the nodes
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  // The map from a CNode's UniqueIdThroughCopy to its OperatorInfo
  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  // The operator_infos in a loop
  std::vector<OperatorInfoPtr> operators_in_forloop;
  // Key: i-th loop; Value: index into 'operators_in_forloop'
  std::map<size_t, size_t> loop_to_ops;
  // extract the strategy from the checkpoint for multi-train
  StrategyMap stra_map;
  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
      StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  }
  for (auto &node : all_nodes) {
    // NOTE: we only care about splittable Primitive operators
    auto cnode = node->cast<CNodePtr>();
    if ((cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)))) {
      continue;
    }
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    if (!IsAutoParallelCareNode(cnode)) {
      // Needed by the rec_parser
      if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
        auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
        if (prev_cnode != nullptr) {
          entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
        }
      }
      continue;
    }
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);

    // Find the OperatorInfo if it exists
    auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
    if (search_cnode == from_cnode_to_info.end()) {
      size_t loop_index = 0;
      bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
      const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
      bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
      if (is_op_created) {
        const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
        if (IsFindWrong(current_op_ptr, prim->name())) {
          MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
                            << " does not match the Prim: " << prim->name()
                            << ". The fullname_with_scope: " << cnode->fullname_with_scope();
        }
        loop_to_ops[loop_index]++;
        cnode->set_user_data<OperatorInfo>(current_op_ptr);
        MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                     << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                     << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
                     << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
        (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), current_op_ptr));
        continue;
      }
      // In this case, the corresponding OperatorInfo has not been created, so create a new one.
      bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
      auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
      MS_EXCEPTION_IF_NULL(operator_info);
      // Needed by the rec_parser
      operator_info->set_type(prim->name());
      operator_info->set_last_node_flag(is_last_nodes);
      std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);

      entire_costgraph->AddOperator(operator_info);
      cnode->set_user_data<OperatorInfo>(operator_info);
      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                   << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
                   << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
      if (single_loop && is_in_loop) {
        operators_in_forloop.push_back(operator_info);
        ops_in_a_loop_.insert(operator_info->name());
        loop_to_ops[loop_index]++;
      }
      // Needed by the rec_parser
      entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
    } else {
      SetOperatorToCNode(search_cnode->second, prim, cnode);
    }
  }
  MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  return SUCCESS;
}

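// Creates a cost-graph edge between two OperatorInfos. Edges across two separate for-loops
// are skipped (their operators share OperatorInfo objects). For Reshape, and for elementwise
// predecessors when 'elementwise_stra_follow' is enabled, the edge forbids redistribution so
// the downstream operator follows the upstream strategy.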
void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const OperatorInfoPtr &node_op_info,
                             const CNodePtr &cnode, const CNodePtr &prev_cnode, const PrimitivePtr &prim,
                             const PrimitivePtr &prev_prim, size_t output_index, size_t input_index,
                             size_t *edge_count) {
  std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
  // If the edge between these two operators has already been added, then it is not added again.
  if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, input_index - 1)) {
    return;
  }
  EdgePtr edge_ptr;
  MS_LOG(INFO) << "Creating edge: " << edge_name;
  if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) {
    MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope()
                 << ", cnode_fullname: " << cnode->fullname_with_scope();
    MS_LOG(INFO) << "The two operators are in two separate for-loops, thus the edge is skipped.";
    return;
  }
  const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
  bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
                         (stra_follow && IsElementWiseOperator(prev_prim->name()));
  if (follow_strategy) {
    // Redistribution is not allowed on the edge.
    // Elementwise operators have the same strategy as their previous operators.
    edge_ptr =
      std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false, true);
  } else {
    edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false);
  }
  // Init the costs for this edge
  if (edge_ptr->InitEdgeCost() != SUCCESS) {
    MS_LOG(EXCEPTION) << "Edge cost initialization failed";
  }
  node_op_info->AddPrevEdge(edge_ptr);
  prev_op_info->AddSuccEdge(edge_ptr);
  entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
  MS_LOG(INFO) << "Successfully added the edge between " << prev_op_info->name() << " and " << node_op_info->name();
  (*edge_count)++;
  return;
}

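// Step 2: for every auto-parallel-care CNode, walk each of its inputs backwards, jumping over
// 'tuple_getitem' (recording the tuple index as the output index) and 'depend' nodes, until
// another care node is found, and create the corresponding edge.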
void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
  // Step 2
  MS_LOG(INFO) << "Constructing edges for cost graph begins.";
  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    auto &inputs = cnode->inputs();
    ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
    if (!IsAutoParallelCareNode(cnode)) {
      continue;
    }
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    size_t edge_count = 0;
    auto node_op_info = cnode->user_data<OperatorInfo>();

    for (size_t i = 1; i < inputs.size(); ++i) {
      auto prev_cnode = inputs[i]->cast<CNodePtr>();
      bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
      if (bool_result_prev_cnode) {
        continue;
      }
      ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
      PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
      size_t output_index = 0;

      while ((IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
             (prev_prim->name() == DEPEND)) {
        if (IsAutoParallelCareNode(prev_cnode)) {
          auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
          CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i,
                                  &edge_count);
          break;
        } else if (prev_prim->name() == prim::kTupleGetItem) {
          // In this case, 'prev_cnode' is 'tuple_getitem'; the actual precursor node is the node
          // before this 'tuple_getitem'
          MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
          output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2))));
          prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
          bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
          if (bool_result_tuple) {
            break;
          }
          prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
          prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
          if (!IsAutoParallelCareNode(prev_cnode)) {
            MS_LOG(EXCEPTION) << "Did not create OperatorInfo for: " << prev_prim->name();
          }
          MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
                       << "and creating an edge between the Operator before "
                       << "'tuple_getitem' and the Operator after 'tuple_getitem'.";
        } else if (prev_prim->name() == DEPEND) {
          // In this case, 'prev_cnode' is 'depend'; the actual precursor node is the node
          // before this 'depend'
          MS_LOG(INFO) << "Jumping the 'depend' operator.";
          prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
          bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
          if (bool_result_depend) {
            break;
          }
          prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
          prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
          MS_LOG(INFO) << "Jumped the 'depend' operator, "
                       << "and creating an edge between the Operator before "
                       << "'depend' and the Operator after 'depend'.";
        }
      }
    }
    MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
  }
  // If 'approximation' is enabled, check that the edges have effective costs.
  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
  if (approximation) {
    entire_costgraph->CheckApproximateCostGraphEdges();
  }
  MS_LOG(INFO) << "Constructing edges for cost graph ends.";
}

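// Step 3: when one Parameter feeds multiple Operators, insert a TmpIdentity operator for the
// Parameter and connect it by an edge to every distinct user, so the cost model accounts for
// the shared tensor's layout. The TmpIdentity is added to the cost graph only if it is new
// and at least one edge was actually created.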
void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  // Step 3
  for (auto &node : all_nodes) {
    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsAutoParallelCareNode);
    auto parameter_name = parameter_users_info.first;
    auto target_parameter = parameter_users_info.second.first;
    auto target_set = parameter_users_info.second.second;
    if (target_set.size() <= 1) {
      continue;
    }

    // Rule out the case where a Parameter is used by an Operator, but the Operator appears in
    // multiple CNodes
    std::set<std::string> target_without_duplicate;
    for (auto &target : target_set) {
      auto target_cnode = target.first->cast<CNodePtr>();
      auto input_index = target.second;
      (void)target_without_duplicate.insert(std::to_string(input_index) +
                                            target_cnode->user_data<OperatorInfo>()->name());
    }
    if (target_without_duplicate.size() <= 1) {
      continue;
    }

    // At this point, it is certain that this Parameter (RefKey) is used by multiple Operators.
    OperatorInfoPtr tmp_identity_ptr;
    bool new_identity = false;
    std::string tmp_identity_name;
    auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
    if (returned_identity != nullptr) {
      // In this case, the TmpIdentityInfo instance has already been created
      new_identity = false;
      tmp_identity_ptr = returned_identity;
      tmp_identity_name = tmp_identity_ptr->name();
    } else {
      // In this case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created.
      new_identity = true;
      // 1) extract the input shape from this Parameter
      MS_EXCEPTION_IF_NULL(target_parameter);
      AbstractBasePtr abstract = target_parameter->abstract();
      if (abstract == nullptr) {
        MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
      }
      auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
      if (input_shape == nullptr) {
        MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
      }
      Shape shape = input_shape->shape();
      Shapes inputs_shape = {shape};
      Shapes outputs_shape = {shape};
      // 2) init the attr
      std::unordered_map<std::string, ValuePtr> attr = {};
      // Create the TmpIdentity instance
      tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
      tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
      TOTAL_OPS++;
      tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
      // Set the parameter and type lengths for the inputs and outputs
      std::vector<bool> is_parameter;
      auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(casted_target_parameter);
      is_parameter.push_back(ParameterRequireGrad(casted_target_parameter));
      if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
      }
      auto node_type = target_parameter->Type();
      if (node_type->isa<mindspore::TensorType>()) {
        auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
        std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
        if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
          MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
        }
      } else {
        MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
      }

      // Generate strategies for this TmpIdentityInfo instance
      if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Strategy search for Operator failed: " << tmp_identity_ptr->name();
      }
    }
    // A flag recording whether new edges have been created or not
    bool add_identity_edge = false;

    // Create edges between this TmpIdentityInfo instance and the subsequent Operator instances
    for (auto &target : target_set) {
      auto target_cnode = target.first->cast<CNodePtr>();
      auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
      auto input_index = target.second;
      auto target_op_info = target_cnode->user_data<OperatorInfo>();

      std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
      // If the edge between these two operators has already been added, then it is not added again.
      if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, LongToSize(input_index - 1))) {
        continue;
      }
      std::shared_ptr<Edge> edge_ptr =
        std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
      // If 'approximation' is enabled, check that the edges have effective costs.
      auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
      if (approximation) {
        target_op_info->ExactStrategiesAndRelatedEdges();
      }

      if (edge_ptr->InitEdgeCost() != SUCCESS) {
        MS_LOG(EXCEPTION) << "Edge cost initialization failed";
      }
      target_op_info->AddPrevEdge(edge_ptr);
      tmp_identity_ptr->AddSuccEdge(edge_ptr);
      entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr);
      MS_LOG(INFO) << "Successfully added the edge between " << tmp_identity_ptr->name() << " and "
                   << target_op_info->name();
      add_identity_edge = true;
    }
    if (new_identity && add_identity_edge) {
      // Add the TmpIdentityInfo to the CostGraph only if BOTH conditions are satisfied
      entire_costgraph->AddOperator(tmp_identity_ptr);
    }
  }
}

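// Step 1.1: computes strategy costs for every Reshape. A Reshape has no strategy of its own;
// it takes the strategy costs (layouts) of its previous operator as input layouts and those
// of its next operator as output layouts, and generates its costs over these layout pairs.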
void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
  std::unordered_set<std::string> op_cache;
  for (auto node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (!FindReshape(cnode, &op_cache)) {
      continue;
    }
    MS_ASSERT(cnode->inputs().size() == 3);
    // get the previous node's strategy_cost_
    auto pre_node = cnode->input(1);
    if (IsPrimitiveCNode(pre_node, prim::kPrimLoad)) {
      pre_node = pre_node->cast<CNodePtr>()->input(1);
    }
    int64_t out_index = 0;
    OperatorInfoPtr pre_operator_info;
    std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
    auto operator_info = cnode->user_data<OperatorInfo>();
    if (pre_node->isa<Parameter>()) {
      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
      reshape_info->SetCostForReshapeWithParameter();
      pre_operator_info = reshape_info;
      pre_stra_costs = reshape_info->strategy_cost();
    } else {
      if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index, 0)) {
        MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed";
      }
      pre_stra_costs = pre_operator_info->strategy_cost();
    }
    // get the next node's strategy_cost_
    int64_t in_index = 0;
    OperatorInfoPtr next_operator_info;
    std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
    bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index, 0);
    if (!find_next_node) {
      MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
    }
    // set the input_layout and output_layout for the reshape;
    // init the reshape and set the cost for each input_layout and output_layout.
    auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
    reshape_info->set_pre_operator_name(pre_operator_info->name());
    reshape_info->set_pre_operator_index(out_index);
    if (find_next_node) {
      next_stra_costs = next_operator_info->strategy_cost();
      reshape_info->set_next_operator_name(next_operator_info->name());
      reshape_info->set_next_operator_index(in_index);
    }
    bool is_prev_param = pre_node->isa<Parameter>();
    if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
        SUCCESS) {
      MS_LOG(EXCEPTION) << "Reshape generate strategy_costs failed!";
    }
  }
}

Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
  // Step 1: Traverse the ANF graph, and create NODEs for the costgraph:
  //         create the OperatorInfo object for each primitive, and enumerate the parallelization
  //         strategies for each OperatorInfo;
  // Step 1.1: Deal with 'Reshape':
  //         'Reshape' takes its previous operator's layout as its input layout, and takes its next
  //         operator's layout as its output layout.
  // Step 2: Traverse the ANF graph, and create EDGES for the costgraph:
  //         create the Edge object for each pair of OperatorInfos, and enumerate the parallelization
  //         strategies for each edge, based on the strategies of the two OperatorInfos;
  // Step 3: Augment the costgraph:
  //         take care of the case of a single Parameter being used by multiple operators. Create a
  //         TmpIdentity operator for this Parameter, and add an edge for the use of this Parameter
  //         by each subsequent operator;
  // Step 3.1: Calculate memory usage:
  //         note that the memory usage calculation differs between the training phase and the
  //         inference phase.
  // Step 4: Run the Dynamic Programming algorithm:
  //         in this process, the cost is calculated based on not only the operators, but also the
  //         edges. Here, the edge cost is caused by the redistribution of an operator's output tensor
  //         layout to the next operator's input tensor layout. Note that there may be several
  //         connected components in the costgraph, and the DP algorithm runs on each of them.
  //
  // OUTPUT: the determined strategy for each operator.
  InitCostGraph();
  // Step 1
  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  } else {
    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  }
  // Step 1.1
  ReshapeCostCompute(all_nodes);
  // Step 2
  ConstructCostGraphEdges(all_nodes);
  MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
               << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";

  // Step 3: Augment the costgraph.
  AugmentCostGraph(all_nodes);
  auto num_ops = entire_costgraph->GetOperators().size();
  SetOpsNumToExecutor(num_ops);
  auto num_edges = entire_costgraph->GetNumEdges();
  MS_LOG(INFO) << "After the augmenting procedure, there are " << num_ops << " operators, and " << num_edges
               << " edges.";

  // Step 3.1: Calculate the memory usage
  if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
    MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
  }

  // Step 4: run the DP algorithm on the costgraph.
  if (GetStrategy(entire_costgraph) != SUCCESS) {
    MS_LOG(ERROR) << "Strategy search for cost-graph fails";
    return FAILED;
  }
  MS_LOG(INFO) << "Searching strategy succeeded.";

  if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
    MS_LOG(INFO) << "Init selected strategy succeeded.";
  } else {
    MS_LOG(EXCEPTION) << "Init selected strategy failed.";
  }

  // print the selected strategy
  for (auto &op : entire_costgraph->GetOperators()) {
    StrategyPtr s_strategy = op->selected_strategy();
    MS_LOG(INFO) << op->name() << " : The strategy is:";
    PrintStrategy(s_strategy);
  }
  ops_in_a_loop_.clear();

  return SUCCESS;
}

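// Replaces, in the (copied) input-tensor-name table, the first occurrence per row of the
// unique id 'it->first' with 'it->second'; used to resolve tuple_getitem indirections for
// the recursive-programming parser.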
std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
                                                          std::vector<std::vector<std::string>> input_tensor_names) {
  for (size_t j = 0; j < input_tensor_names.size(); j++) {
    for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
      if (it->first == input_tensor_names[j][k]) {
        input_tensor_names[j][k] = it->second;
        break;
      }
    }
  }
  return input_tensor_names;
}

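// For a 'tuple_getitem' or 'depend' CNode, walks backwards through any chain of such nodes
// and returns the first CNode holding a real Primitive; returns nullptr otherwise.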
CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  if (prim->name() == prim::kTupleGetItem || prim->name() == DEPEND) {
    auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
    if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
      return nullptr;
    }
    auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    while (prev_prim->name() == prim::kTupleGetItem || prev_prim->name() == DEPEND) {
      prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
      if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
        return nullptr;
      }
      prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    }
    return prev_cnode;
  }
  return nullptr;
}

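// If an OperatorInfo named 'name' has already been created, rewrites every occurrence of
// 'uniqueid' in the input-tensor-name table with the first entry of that operator's row,
// keeping the rec_parser's table consistent with the reused OperatorInfo.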
void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
  size_t iter_ops = 0;
  for (auto op : entire_costgraph->GetOperators()) {
    if (op->name() == name) {
      break;
    }
    iter_ops = iter_ops + 1;
  }

  std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
  for (size_t i = 0; i < input_tensor_names.size(); i++) {
    for (size_t j = 0; j < input_tensor_names[i].size(); j++) {
      if (input_tensor_names[i][j] == uniqueid) {
        input_tensor_names[i][j] = input_tensor_names[iter_ops][0];
      }
    }
  }

  entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
}

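// Strategy search with the recursive-programming algorithm: builds the same cost graph as the
// DP path, parses it into the rec_core graph representation, eliminates non-computation nodes,
// partitions the graph over all devices under the device-memory constraint, and then generates
// and applies a strategy for every operator.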
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  InitCostGraph();
  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  } else {
    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  }
  ReshapeCostCompute(all_nodes);

  auto ops = entire_costgraph->GetOperators();
  std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
  auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
  for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
    input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
  }

  std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);

  std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
  std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
  graph = EliminateGraph(graph, eli_list, index_list);

  size_t num_device = g_device_manager->DeviceNum();
  const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
  if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
    MS_LOG(INFO) << "Partition succeeded with " << num_device << " devices.";
  } else {
    MS_LOG(ERROR) << "PartitionForAllDevices failed.";
    return FAILED;
  }

  bool is_training = true;
  if (!root->has_flag(TRAINING)) {
    is_training = false;
  }
  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training);

  if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
    MS_LOG(INFO) << "Init selected strategy succeeded.";
  } else {
    MS_LOG(ERROR) << "Init selected strategy failed.";
    return FAILED;
  }

  // print the selected strategy
  for (auto &op : entire_costgraph->GetOperators()) {
    StrategyPtr s_strategy = op->selected_strategy();
    MS_LOG(INFO) << op->name() << " : The strategy is:";
    PrintStrategy(s_strategy);
  }

  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore