You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

step_auto_parallel.cc 52 kB

5 years ago
5 years ago
5 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/step_auto_parallel.h"
  17. #include <inttypes.h>
  18. #include <sys/time.h>
  19. #include <algorithm>
  20. #include <map>
  21. #include <memory>
  22. #include <set>
  23. #include <string>
  24. #include <unordered_map>
  25. #include <utility>
  26. #include <vector>
  27. #include "ir/anf.h"
  28. #include "ir/param_value_py.h"
  29. #include "ir/tensor.h"
  30. #include "optimizer/opt.h"
  31. #include "optimizer/optimizer.h"
  32. #include "parallel/auto_parallel/dp_algo_costmodel.h"
  33. #include "parallel/auto_parallel/edge_costmodel.h"
  34. #include "parallel/auto_parallel/graph_costmodel.h"
  35. #include "parallel/auto_parallel/rec_core/rec_generate_strategy.h"
  36. #include "parallel/auto_parallel/rec_core/rec_parse_graph.h"
  37. #include "parallel/auto_parallel/rec_core/rec_partition.h"
  38. #include "parallel/context.h"
  39. #include "parallel/ops_info/tmp_identity_info.h"
  40. #include "parallel/ops_info/reshape_info.h"
  41. #include "parallel/step_parallel.h"
  42. #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
  43. #include "pipeline/parse/python_adapter.h"
  44. #include "pipeline/pipeline.h"
  45. namespace mindspore {
  46. namespace parallel {
  47. bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
  48. MS_EXCEPTION_IF_NULL(root);
  49. MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  50. std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  51. // assume no change to graph
  52. bool changes = false;
  53. // control whether use model_parallel mode
  54. if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
  55. root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
  56. return changes;
  57. }
  58. // check whether strategy_search_mode is valid
  59. std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
  60. if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
  61. // Setting searching mode: dynanic programming as default.
  62. strategy_search_mode = DYNAMIC_PROGRAMMING;
  63. MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default";
  64. }
  65. struct timeval start_time, end_time;
  66. (void)gettimeofday(&start_time, nullptr);
  67. if (MsContext::GetInstance()->save_graphs_flag()) {
  68. draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
  69. }
  70. MS_LOG(INFO) << "Now entering step auto parallel";
  71. TOTAL_OPS = 0;
  72. AnfNodePtr ret = root->get_return();
  73. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  74. if (ParallelInit() != SUCCESS) {
  75. MS_LOG(EXCEPTION) << "Parallel init failed";
  76. }
  77. // mark the forward cnodes, parallel only care these nodes
  78. MarkForwardCNode(root);
  79. if (FindCommunicationOp(all_nodes)) {
  80. MS_LOG(EXCEPTION) << "The graph contain communication op";
  81. }
  82. // search parallelization strategy
  83. if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
  84. if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
  85. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
  86. }
  87. } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
  88. if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
  89. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
  90. }
  91. } else {
  92. MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
  93. }
  94. (void)gettimeofday(&end_time, nullptr);
  95. uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  96. time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  97. MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
  98. root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY] = true;
  99. return changes;
  100. }
  101. // Given the node, return whether each input is a parameter or a output of a operator.
  102. // The returned boolean vector should be the same order of the inputs, thus its implementation
  103. // is closely consistent with ExtractShape() in step_parallel.cc
  104. std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
  105. std::vector<bool> is_parameter;
  106. std::vector<AnfNodePtr> node_inputs{node->inputs()};
  107. for (size_t i = 1; i < node_inputs.size(); ++i) {
  108. auto input = node_inputs[i];
  109. if (input->isa<Parameter>()) {
  110. auto input_parameter = input->cast<ParameterPtr>();
  111. if (input_parameter->has_default()) {
  112. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
  113. bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
  114. is_parameter.push_back(require_grad);
  115. } else {
  116. is_parameter.push_back(false);
  117. }
  118. } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
  119. is_parameter.push_back(false);
  120. }
  121. }
  122. return is_parameter;
  123. }
  124. // Given the type, return the number of bytes to represent this type
  125. size_t GetLengthOfDataType(const TypePtr &type) {
  126. switch (type->type_id()) {
  127. case kNumberTypeBool:
  128. return sizeof(bool);
  129. case kNumberTypeInt8:
  130. return sizeof(int8_t);
  131. case kNumberTypeInt16:
  132. return sizeof(int16_t);
  133. case kNumberTypeInt32:
  134. return sizeof(int32_t);
  135. case kNumberTypeInt64:
  136. return sizeof(int64_t);
  137. case kNumberTypeUInt8:
  138. return sizeof(uint8_t);
  139. case kNumberTypeUInt16:
  140. return sizeof(uint16_t);
  141. case kNumberTypeUInt32:
  142. return sizeof(uint32_t);
  143. case kNumberTypeUInt64:
  144. return sizeof(uint64_t);
  145. case kNumberTypeFloat16:
  146. return sizeof(float) / 2;
  147. case kNumberTypeFloat32:
  148. return sizeof(float);
  149. case kNumberTypeFloat64:
  150. return sizeof(double);
  151. case kNumberTypeInt:
  152. return sizeof(int);
  153. case kNumberTypeUInt:
  154. return sizeof(unsigned int);
  155. case kNumberTypeFloat:
  156. return sizeof(float);
  157. default:
  158. MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  159. }
  160. }
  161. size_t GetInputsTypeLen(const AnfNodePtr &input) {
  162. MS_EXCEPTION_IF_NULL(input);
  163. if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
  164. MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
  165. }
  166. size_t input_type_len = 0;
  167. auto type = input->Type();
  168. MS_EXCEPTION_IF_NULL(type);
  169. if (type->isa<mindspore::TensorType>()) {
  170. auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
  171. input_type_len = GetLengthOfDataType(input_element_type);
  172. } else {
  173. MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  174. }
  175. return input_type_len;
  176. }
  177. std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
  178. MS_EXCEPTION_IF_NULL(node);
  179. std::vector<size_t> inputs_type_len;
  180. std::vector<AnfNodePtr> node_inputs{node->inputs()};
  181. // extract input element length
  182. for (auto &input : node_inputs) {
  183. if (IsValueNode<RefKey>(input)) {
  184. auto func_graph = node->func_graph();
  185. MS_EXCEPTION_IF_NULL(func_graph);
  186. std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
  187. if (parameters.size() != 1) {
  188. MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  189. }
  190. inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
  191. } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
  192. // extract input shape from parameter and apply node
  193. inputs_type_len.push_back(GetInputsTypeLen(input));
  194. }
  195. }
  196. return inputs_type_len;
  197. }
  198. std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
  199. MS_EXCEPTION_IF_NULL(node);
  200. std::vector<TypePtr> outputs_type;
  201. // extract output element type
  202. auto primary_output_type = node->Type();
  203. MS_EXCEPTION_IF_NULL(primary_output_type);
  204. if (primary_output_type->isa<mindspore::Tuple>()) {
  205. // in this case, the output is a tuple
  206. auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
  207. auto elements = tuple_output_type->elements();
  208. for (auto &ele : elements) {
  209. if (ele->isa<mindspore::TensorType>()) {
  210. auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
  211. outputs_type.push_back(ele_element_type);
  212. } else {
  213. MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
  214. }
  215. }
  216. } else {
  217. // in this case, the output is a single tensor
  218. if (primary_output_type->isa<mindspore::TensorType>()) {
  219. auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
  220. outputs_type.push_back(element_type);
  221. } else {
  222. MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
  223. }
  224. }
  225. return outputs_type;
  226. }
  227. bool IsElementWiseOperator(const std::string &op_name) {
  228. static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU,
  229. SQRT, CAST, POW, EXP, LOG, COS,
  230. ACOS, LOGICALNOT, NEG, SQUARE, SIGMOID};
  231. auto iter = elementwise_op.find(op_name);
  232. return (iter != elementwise_op.end());
  233. }
  234. bool IsSplittableOperator(const std::string &op_name) {
  235. // clang-format off
  236. static const std::set<std::string> splittable_op =
  237. {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
  238. FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
  239. REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
  240. MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
  241. LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
  242. STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2,
  243. SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
  244. // clang-format on
  245. auto iter = splittable_op.find(op_name);
  246. return (iter != splittable_op.end());
  247. }
  248. bool IsAutoParallelCareNode(const CNodePtr &cnode) {
  249. MS_EXCEPTION_IF_NULL(cnode);
  250. ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  251. if (prim_node == nullptr) {
  252. return false;
  253. }
  254. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
  255. if (prim == nullptr) {
  256. return false;
  257. }
  258. bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
  259. if (bool_result) {
  260. MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
  261. } else if (prim->name() == CAST) {
  262. return true;
  263. }
  264. return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
  265. }
  266. OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
  267. MS_EXCEPTION_IF_NULL(prim);
  268. MS_EXCEPTION_IF_NULL(cnode);
  269. auto attrs = prim->attrs();
  270. std::vector<Shapes> shape_list = ExtractShape(cnode);
  271. if (shape_list.empty()) {
  272. MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
  273. }
  274. // Create an OperatorInfo instance
  275. OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
  276. MS_EXCEPTION_IF_NULL(operator_info);
  277. // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
  278. std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
  279. if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
  280. MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
  281. return nullptr;
  282. }
  283. // Set the data type for inputs and outputs of this OperatorInfo
  284. auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
  285. auto outputs_type = ExtractOutputTypeByNode(cnode);
  286. std::vector<size_t> outputs_type_length;
  287. outputs_type_length.reserve(outputs_type.size());
  288. std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
  289. GetLengthOfDataType);
  290. if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
  291. MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
  292. return nullptr;
  293. }
  294. if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
  295. MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
  296. return nullptr;
  297. }
  298. // When the 'inputs' contains numerical values for some operators, these values should be extracted from
  299. // ANF graph
  300. auto &inputs = cnode->inputs();
  301. std::vector<ValuePtr> input_value;
  302. for (size_t index = 1; index < inputs.size(); ++index) {
  303. if (inputs[index]->isa<ValueNode>()) {
  304. input_value.push_back(GetValueNode(inputs[index]));
  305. } else {
  306. input_value.emplace_back(nullptr);
  307. }
  308. }
  309. operator_info->set_input_value(input_value);
  310. operator_info->set_outputs_dtype(cnode->Type());
  311. operator_info->set_cnode(cnode);
  312. // key of strategy map
  313. std::string strategy_key_name = NodeParameterName(cnode);
  314. bool load_strategy_from_ckpt =
  315. StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
  316. // If no strategy has been configured for this operator, then candidate strategies are generated for
  317. // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
  318. // if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint .
  319. if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
  320. // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
  321. // BatchParallelInfo operator
  322. operator_info->ComputeBatchSplitFlagList();
  323. if (operator_info->GenerateStrategies(0) != SUCCESS) {
  324. MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
  325. return nullptr;
  326. }
  327. } else {
  328. // In this case, the configured strategy should be extracted to help setting cost
  329. StrategyPtr strategyPtr;
  330. if (load_strategy_from_ckpt) {
  331. strategyPtr = (*stra_map)[strategy_key_name];
  332. } else {
  333. strategyPtr = parallel::ExtractStrategy(attrs);
  334. }
  335. if (strategyPtr != nullptr) {
  336. if (prim->name() == RESHAPE) {
  337. MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
  338. }
  339. // Set cost for this configured strategy
  340. if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
  341. MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
  342. } else if (FULLY_USE_DEVICES) {
  343. // If configured to fully use devices, then checking for the user-specified strategy
  344. int32_t used_devices = operator_info->used_devices();
  345. MS_EXCEPTION_IF_NULL(g_device_manager);
  346. auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
  347. // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel
  348. if (used_devices == 1) {
  349. return operator_info;
  350. }
  351. // 'used_devices == -1' means that 'used_devices_' is not set
  352. if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) {
  353. MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, "
  354. << "but the specified strategy uses device: " << used_devices
  355. << ", total devices: " << total_device_num;
  356. }
  357. }
  358. }
  359. }
  360. return operator_info;
  361. }
  362. // Using CNode's UniqueIds to construct nodes
  363. Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  364. MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  365. entire_costgraph = std::make_shared<CostGraph>();
  366. entire_costgraph->SetDeviceMemoryAndCostParameter();
  367. // The map from CNode's UniqueId to its operatorInfo
  368. std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  369. // extract strategy from checkpoint for multi-train
  370. StrategyMap stra_map;
  371. if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
  372. if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
  373. MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  374. }
  375. }
  376. // Step 1
  377. for (auto &node : all_nodes) {
  378. // NOTE: we only care about splittable Primitive operators
  379. auto cnode = node->cast<CNodePtr>();
  380. bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
  381. if (bool_result) {
  382. continue;
  383. }
  384. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  385. if (!IsAutoParallelCareNode(cnode)) {
  386. // Needed by rec_parser
  387. if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
  388. auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
  389. if (prev_cnode != nullptr) {
  390. entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
  391. }
  392. }
  393. continue;
  394. }
  395. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  396. MS_EXCEPTION_IF_NULL(prim);
  397. auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
  398. if (search_cnode == from_cnode_to_info.end()) {
  399. auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
  400. if (operator_info == nullptr) {
  401. return FAILED;
  402. }
  403. // Needed by rec_parser
  404. operator_info->set_type(prim->name());
  405. std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
  406. entire_costgraph->AddOperator(operator_info);
  407. (void)cnode->set_operator_info(operator_info);
  408. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  409. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  410. << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
  411. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
  412. // Needed by rec_parser
  413. entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
  414. } else {
  415. // Two CNODEs' UniqueIds should not be equal
  416. MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
  417. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  418. << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
  419. }
  420. }
  421. MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  422. return SUCCESS;
  423. }
  424. // Using CNode's UniqueIdThroughCopys to construct nodes
  425. Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  426. MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  427. entire_costgraph = std::make_shared<CostGraph>();
  428. entire_costgraph->SetDeviceMemoryAndCostParameter();
  429. // The map from CNode's UniqueIdThroughCopy to its operatorInfo
  430. std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  431. // extract strategy from checkpoint for multi-train
  432. StrategyMap stra_map;
  433. if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
  434. if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
  435. MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  436. }
  437. }
  438. for (auto &node : all_nodes) {
  439. // NOTE: we only care about splittable Primitive operators
  440. auto cnode = node->cast<CNodePtr>();
  441. bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
  442. if (bool_result) {
  443. continue;
  444. }
  445. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  446. if (!IsAutoParallelCareNode(cnode)) {
  447. // Needed by rec_parser
  448. if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
  449. auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
  450. if (prev_cnode != nullptr) {
  451. entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
  452. }
  453. }
  454. continue;
  455. }
  456. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  457. // Find the operatorInfo if it exists
  458. auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
  459. if (search_cnode == from_cnode_to_info.end()) {
  460. // In this case, the corresponding OperatorInfo is not created, create the new one.
  461. auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
  462. if (operator_info == nullptr) {
  463. return FAILED;
  464. }
  465. // Needed by rec_parser
  466. operator_info->set_type(prim->name());
  467. std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
  468. entire_costgraph->AddOperator(operator_info);
  469. (void)cnode->set_operator_info(operator_info);
  470. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  471. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  472. << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
  473. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
  474. // Needed by rec_parser
  475. entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
  476. } else {
  477. auto current_op_ptr = search_cnode->second;
  478. if (current_op_ptr == nullptr) {
  479. MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
  480. } else {
  481. bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
  482. (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
  483. (current_op_ptr->name().find(prim->name()) == std::string::npos);
  484. if (is_find_wrong) {
  485. MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
  486. << " does not match the Prim: " << prim->name();
  487. }
  488. (void)cnode->set_operator_info(current_op_ptr);
  489. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  490. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  491. << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
  492. }
  493. }
  494. }
  495. MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  496. return SUCCESS;
  497. }
  498. void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
  499. // Step 2
  500. MS_LOG(INFO) << "Constructing edges for cost graph begins.";
  501. for (auto &node : all_nodes) {
  502. auto cnode = node->cast<CNodePtr>();
  503. bool bool_result_cnode = (cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0));
  504. if (bool_result_cnode) {
  505. continue;
  506. }
  507. auto &inputs = cnode->inputs();
  508. ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
  509. if (!IsAutoParallelCareNode(cnode)) {
  510. continue;
  511. }
  512. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  513. size_t edge_count = 0;
  514. for (size_t i = 1; i < inputs.size(); ++i) {
  515. auto prev_cnode = inputs[i]->cast<CNodePtr>();
  516. bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  517. if (bool_result_prev_cnode) {
  518. continue;
  519. }
  520. ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  521. PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  522. size_t output_index = 0;
  523. bool bool_result =
  524. (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
  525. while (bool_result) {
  526. if (IsAutoParallelCareNode(prev_cnode)) {
  527. std::string edge_name =
  528. prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name();
  529. // If the edge between these two operators already has been added, then the edge will not be added again.
  530. if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
  531. break;
  532. }
  533. EdgePtr edge_ptr;
  534. MS_LOG(INFO) << "Creating edge: " << edge_name;
  535. bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
  536. (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
  537. if (follow_strategy) {
  538. // Redistribution in not allowed on the edge.
  539. // Elementwise operators have the same strategy as their previous operators.
  540. edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
  541. output_index, i - 1, false, true);
  542. } else {
  543. edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
  544. output_index, i - 1, false);
  545. }
  546. // Init costs for this edge
  547. if (edge_ptr->InitEdgeCost() != SUCCESS) {
  548. MS_LOG(EXCEPTION) << "Edge cost initialization failed";
  549. }
  550. cnode->operator_info()->AddPrevEdge(edge_ptr);
  551. prev_cnode->operator_info()->AddSuccEdge(edge_ptr);
  552. entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr);
  553. MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and "
  554. << cnode->operator_info()->name();
  555. edge_count++;
  556. break;
  557. } else if (prev_prim->name() == TUPLE_GETITEM) {
  558. // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
  559. // this 'tuple_getitem'
  560. MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
  561. output_index = IntToSize(GetValue<int>(GetValueNode(prev_cnode->input(2))));
  562. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  563. bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  564. if (bool_result_tuple) {
  565. break;
  566. }
  567. prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  568. prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  569. if (!IsAutoParallelCareNode(prev_cnode)) {
  570. MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
  571. }
  572. MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
  573. << "and creating an edge between the Operator before "
  574. << "'tuple_getitem' and the Operator after 'tuple_getitem'.";
  575. } else if (prev_prim->name() == DEPEND) {
  576. // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before
  577. // this 'depend'
  578. MS_LOG(INFO) << "Jumping the 'depend' operator.";
  579. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  580. bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  581. if (bool_result_depend) {
  582. break;
  583. }
  584. prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  585. prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  586. MS_LOG(INFO) << "Jumped the 'depend' operator, "
  587. << "and creating an edge between the Operator before "
  588. << "'depend' and the Operator after 'depend'.";
  589. }
  590. bool_result =
  591. (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
  592. }
  593. }
  594. MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name();
  595. }
  596. MS_LOG(INFO) << "Constructing edges for cost graph ends.";
  597. }
  598. std::pair<AnfNodePtr, std::vector<AnfNodePtr>> CNodeWithRefKeys(const AnfNodePtr &cnode) {
  599. MS_EXCEPTION_IF_NULL(cnode);
  600. std::vector<AnfNodePtr> refkeys;
  601. if (cnode->isa<CNode>()) {
  602. auto cnode_ptr = cnode->cast<CNodePtr>();
  603. auto inputs = cnode_ptr->inputs();
  604. for (auto &one_input : inputs) {
  605. if (IsValueNode<RefKey>(one_input)) {
  606. refkeys.push_back(one_input);
  607. }
  608. }
  609. if (refkeys.size() >= 1) {
  610. return std::make_pair(cnode, refkeys);
  611. }
  612. }
  613. return {nullptr, refkeys};
  614. }
  615. void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  616. // Step 3
  617. for (auto &node : all_nodes) {
  618. auto cnode_with_refkeys = CNodeWithRefKeys(node);
  619. if ((!node->isa<Parameter>()) && (cnode_with_refkeys.first == nullptr)) {
  620. continue;
  621. }
  622. std::string parameter_name;
  623. AnfNodePtr target_parameter = nullptr;
  624. AnfNodeIndexSet target_set;
  625. if (cnode_with_refkeys.first != nullptr) {
  626. // Dealing with the RefKey case
  627. auto refkeys = cnode_with_refkeys.second;
  628. auto cnode = cnode_with_refkeys.first;
  629. auto cnode_ptr = cnode->cast<CNodePtr>();
  630. if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
  631. continue;
  632. }
  633. if (!IsAutoParallelCareNode(cnode_ptr)) {
  634. continue;
  635. }
  636. if (refkeys.size() > 1) {
  637. MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
  638. }
  639. MS_EXCEPTION_IF_NULL(cnode->func_graph());
  640. auto cnode_func_graph = cnode->func_graph();
  641. MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
  642. // Find the RefKey being used
  643. auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
  644. for (auto &candidate : candidate_set_by_refkey) {
  645. auto candidate_node = candidate.first;
  646. auto c = candidate_node->cast<CNodePtr>();
  647. if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
  648. continue;
  649. }
  650. if (!IsAutoParallelCareNode(c)) {
  651. continue;
  652. }
  653. target_set.add(candidate);
  654. }
  655. // Find the corresponding Parameter being used
  656. std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
  657. if (parameters.size() != 1) {
  658. MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  659. }
  660. parameter_name = parameters[0]->cast<ParameterPtr>()->name();
  661. target_parameter = parameters[0];
  662. auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
  663. for (auto &candidate : candidate_set_by_para) {
  664. auto candidate_node = candidate.first;
  665. auto c = candidate_node->cast<CNodePtr>();
  666. if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
  667. continue;
  668. }
  669. if (!IsAutoParallelCareNode(c)) {
  670. continue;
  671. }
  672. (void)target_set.insert(candidate);
  673. }
  674. } else if (node->isa<Parameter>()) {
  675. // Dealing with the Parameter case
  676. MS_EXCEPTION_IF_NULL(node->func_graph());
  677. MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
  678. auto candidate_set = node->func_graph()->manager()->node_users()[node];
  679. for (auto &candidate : candidate_set) {
  680. auto candidate_node = candidate.first;
  681. auto c = candidate_node->cast<CNodePtr>();
  682. if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
  683. continue;
  684. }
  685. if (!IsAutoParallelCareNode(c)) {
  686. continue;
  687. }
  688. (void)target_set.insert(candidate);
  689. }
  690. // In this case, node is a Parameter
  691. parameter_name = node->cast<ParameterPtr>()->name();
  692. target_parameter = node;
  693. }
  694. if (target_set.size() <= 1) {
  695. continue;
  696. }
  697. // Rule out the case when a Parameter being used by a Operator, but the Operator appears in multiple CNODEs
  698. std::set<std::string> target_without_duplicate;
  699. for (auto &target : target_set) {
  700. auto target_cnode = target.first->cast<CNodePtr>();
  701. auto input_index = target.second;
  702. (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name());
  703. }
  704. if (target_without_duplicate.size() <= 1) {
  705. continue;
  706. }
  707. // Here, it is sure that this Parameter (RefKey) is being used by multiple Operators.
  708. OperatorInfoPtr tmp_identity_ptr;
  709. bool new_identity = false;
  710. std::string tmp_identity_name;
  711. auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
  712. if (returned_identity != nullptr) {
  713. // In this case, the TmpIdentityInfo instance has already been created
  714. new_identity = false;
  715. tmp_identity_ptr = returned_identity;
  716. tmp_identity_name = tmp_identity_ptr->name();
  717. } else {
  718. // In the case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created.
  719. new_identity = true;
  720. // 1) extract input shape from this Parameter
  721. MS_EXCEPTION_IF_NULL(target_parameter);
  722. AbstractBasePtr abstract = target_parameter->abstract();
  723. if (abstract == nullptr) {
  724. MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
  725. }
  726. auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
  727. if (input_shape == nullptr) {
  728. MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
  729. }
  730. std::vector<int> shape_int = input_shape->shape();
  731. Shape shape;
  732. (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape),
  733. [](int sub_shape) { return static_cast<int32_t>(sub_shape); });
  734. Shapes inputs_shape = {shape};
  735. Shapes outputs_shape = {shape};
  736. // 2) init the attr
  737. std::unordered_map<std::string, ValuePtr> attr = {};
  738. // Create the TmpIdentity instance
  739. tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
  740. tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
  741. TOTAL_OPS++;
  742. tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
  743. // Set the parameter and type lengths for inputs and outputs
  744. std::vector<bool> is_parameter;
  745. auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
  746. MS_EXCEPTION_IF_NULL(casted_target_parameter);
  747. if (casted_target_parameter->has_default()) {
  748. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(casted_target_parameter->default_param());
  749. bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
  750. is_parameter.push_back(require_grad);
  751. } else {
  752. is_parameter.push_back(false);
  753. }
  754. if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
  755. MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
  756. }
  757. auto node_type = target_parameter->Type();
  758. if (node_type->isa<mindspore::TensorType>()) {
  759. auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
  760. std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
  761. if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
  762. MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
  763. }
  764. } else {
  765. MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
  766. }
  767. // Generate strategies for this TmpIdentityInfo instance;
  768. if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
  769. MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name();
  770. }
  771. }
  772. // A flag recording whether new edges have been created or not
  773. bool add_identity_edge = false;
  774. // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
  775. for (auto &target : target_set) {
  776. auto target_cnode = target.first->cast<CNodePtr>();
  777. auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
  778. auto input_index = target.second;
  779. std::string edge_name =
  780. std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name();
  781. // If the edge between these two operators already has been added, then the edge will not be added again.
  782. if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) {
  783. continue;
  784. }
  785. std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
  786. edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
  787. if (edge_ptr->InitEdgeCost() != SUCCESS) {
  788. MS_LOG(EXCEPTION) << "Edge cost initialization failed";
  789. }
  790. target_cnode->operator_info()->AddPrevEdge(edge_ptr);
  791. tmp_identity_ptr->AddSuccEdge(edge_ptr);
  792. entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr);
  793. MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
  794. << target_cnode->operator_info()->name();
  795. add_identity_edge = true;
  796. }
  797. if (new_identity && add_identity_edge) {
  798. // Add the TmpIdentityInfo to CostGraph if BOTH two conditions are satisfied
  799. entire_costgraph->AddOperator(tmp_identity_ptr);
  800. }
  801. }
  802. }
  803. bool FindReshape(const CNodePtr &cnode) {
  804. if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
  805. return false;
  806. }
  807. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  808. if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
  809. return false;
  810. }
  811. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  812. MS_EXCEPTION_IF_NULL(prim);
  813. OperatorInfoPtr operator_info = cnode->operator_info();
  814. if (operator_info == nullptr) {
  815. MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
  816. }
  817. if (prim->name() != RESHAPE) {
  818. return false;
  819. }
  820. return true;
  821. }
  822. // find previous node, then obtain its strategy_cost_ vector to get its layout vector.
  823. bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
  824. // if previous node is a parameter, handle it in the outsize.
  825. if (node->isa<Parameter>()) {
  826. return false;
  827. }
  828. if (!node->isa<CNode>()) {
  829. return false;
  830. }
  831. CNodePtr cnode = node->cast<CNodePtr>();
  832. if (!IsValueNode<Primitive>(cnode->input(0))) {
  833. return false;
  834. }
  835. if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
  836. *pre_operator_info = cnode->operator_info();
  837. *out_index = 0;
  838. return true;
  839. }
  840. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  841. PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  842. if (prim->name() == TUPLE_GETITEM) {
  843. *out_index = GetTupleGetItemIndex(cnode);
  844. // find tuple_get_item's previous node
  845. auto pre_node = cnode->input(1);
  846. if (!pre_node->isa<CNode>()) {
  847. MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
  848. }
  849. CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
  850. if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
  851. *pre_operator_info = pre_cnode->operator_info();
  852. return true;
  853. }
  854. return false;
  855. }
  856. for (size_t index = 0; index < cnode->inputs().size(); ++index) {
  857. if (prim->name() == DEPEND && index != 1) {
  858. continue;
  859. }
  860. if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
  861. continue;
  862. }
  863. return true;
  864. }
  865. MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
  866. return false;
  867. }
  868. // find next node, then obtain its strategy_cost_ vector to get its layout vector.
  869. // if reshape's output connect to several primitive, return the first layout found
  870. bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
  871. MS_EXCEPTION_IF_NULL(cnode);
  872. MS_EXCEPTION_IF_NULL(cnode->func_graph());
  873. FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  874. MS_EXCEPTION_IF_NULL(manager);
  875. AnfNodeIndexSet node_set = manager->node_users()[cnode];
  876. for (auto &node_pair : node_set) {
  877. CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
  878. if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
  879. continue;
  880. }
  881. ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
  882. MS_EXCEPTION_IF_NULL(prim_anf_node);
  883. PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
  884. MS_EXCEPTION_IF_NULL(node_prim);
  885. MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
  886. if (node_prim->name() == DEPEND && node_pair.second != 1) {
  887. continue;
  888. }
  889. if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
  890. MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
  891. *next_operator_info = use_apply->operator_info();
  892. *in_index = node_pair.second - 1;
  893. return true;
  894. }
  895. MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
  896. << " " << (use_apply->operator_info() != nullptr);
  897. if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
  898. return true;
  899. }
  900. }
  901. return false;
  902. }
  903. void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
  904. for (auto node : all_nodes) {
  905. auto cnode = node->cast<CNodePtr>();
  906. if (!FindReshape(cnode)) {
  907. continue;
  908. }
  909. MS_ASSERT(cnode->inputs().size() == 3);
  910. // get previous node's strategy_cost_
  911. auto pre_node = cnode->input(1);
  912. int32_t out_index = 0;
  913. OperatorInfoPtr pre_operator_info;
  914. std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
  915. if (pre_node->isa<Parameter>()) {
  916. OperatorInfoPtr operator_info = cnode->operator_info();
  917. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
  918. reshape_info->SetCostForReshapeWithParameter();
  919. pre_operator_info = reshape_info;
  920. pre_stra_costs = reshape_info->strategy_cost();
  921. } else {
  922. if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
  923. MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
  924. }
  925. pre_stra_costs = pre_operator_info->strategy_cost();
  926. }
  927. // get next node's strategy_cost_
  928. int32_t in_index = 0;
  929. OperatorInfoPtr next_operator_info;
  930. std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
  931. bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
  932. if (!find_next_node) {
  933. MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
  934. }
  935. // set input_layout and output_layout for reshape.
  936. // init reshape and set cost for each input_layout and output_layout.
  937. OperatorInfoPtr operator_info = cnode->operator_info();
  938. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
  939. reshape_info->set_pre_operator_name(pre_operator_info->name());
  940. reshape_info->set_pre_operator_index(out_index);
  941. if (find_next_node) {
  942. next_stra_costs = next_operator_info->strategy_cost();
  943. reshape_info->set_next_operator_name(next_operator_info->name());
  944. reshape_info->set_next_operator_index(in_index);
  945. }
  946. bool is_prev_param = pre_node->isa<Parameter>();
  947. if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
  948. SUCCESS) {
  949. MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!";
  950. }
  951. }
  952. }
  953. Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  954. // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
  955. // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
  956. // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
  957. // for each OperatorInfo;
  958. // Step 1.1: Deal with 'Reshape':
  959. // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
  960. // layout as its output layout.
  961. // Step 2: Traverse the ANF graph, and create EDGES for costgraph:
  962. // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
  963. // for each edge, based on the strategies of two OperatorInfos;
  964. // Step 3: Augment the costgraph:
  965. // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
  966. // operator for this Parameter, and add an edge for the use of this Parameter by each
  967. // subsequent operator;
  968. // Step 3.1: Calculate memory usage:
  969. // note the memory usage calculation is different in training phase and inference phase.
  970. // Step 4: Run the Dynamic Programming algorithm:
  971. // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
  972. // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
  973. // tensor layout. Note that there may be several connected components in the costgraph, and the DP algorithm
  974. // runs on each of them.
  975. //
  976. // OUTPUT: the determined strategy for each operator.
  977. // Step 1
  978. if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
  979. if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
  980. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  981. << entire_costgraph->GetOperators().size() << " operators.";
  982. } else {
  983. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  984. }
  985. } else {
  986. if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
  987. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  988. << entire_costgraph->GetOperators().size() << " operators.";
  989. } else {
  990. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  991. }
  992. }
  993. // Step 1.1
  994. ReshapeCostCompute(all_nodes);
  995. // Step 2
  996. ConstructCostGraphEdges(all_nodes);
  997. MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
  998. << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
  999. // Step 3: Augment the costgraph.
  1000. AugmentCostGraph(all_nodes);
  1001. MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size()
  1002. << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
  1003. // Step 3.1: Calculate the memory usage
  1004. if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
  1005. MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
  1006. }
  1007. // Step 4: run DP algorithm on the costgraph.
  1008. if (GetStrategy(entire_costgraph) != SUCCESS) {
  1009. MS_LOG(ERROR) << "Strategy search for cost-graph fails";
  1010. return FAILED;
  1011. }
  1012. MS_LOG(INFO) << "Searching strategy succeeded.";
  1013. if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
  1014. MS_LOG(INFO) << "Init selected strategy succeeded.";
  1015. } else {
  1016. MS_LOG(EXCEPTION) << "Init selected strategy failed.";
  1017. }
  1018. // print the selected strategy
  1019. for (auto &op : entire_costgraph->GetOperators()) {
  1020. StrategyPtr s_strategy = op->selected_strategy();
  1021. MS_LOG(INFO) << op->name() << " : The strategy is:";
  1022. PrintStrategy(s_strategy);
  1023. }
  1024. return SUCCESS;
  1025. }
  1026. std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
  1027. std::vector<std::vector<std::string>> input_tensor_names) {
  1028. for (size_t j = 0; j < input_tensor_names.size(); j++) {
  1029. for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
  1030. if (it->first == input_tensor_names[j][k]) {
  1031. input_tensor_names[j][k] = it->second;
  1032. break;
  1033. }
  1034. }
  1035. }
  1036. return input_tensor_names;
  1037. }
  1038. CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
  1039. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  1040. if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) {
  1041. auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
  1042. if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
  1043. return nullptr;
  1044. }
  1045. auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  1046. while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) {
  1047. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  1048. if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
  1049. return nullptr;
  1050. }
  1051. prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  1052. }
  1053. return prev_cnode;
  1054. }
  1055. return nullptr;
  1056. }
  1057. Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  1058. if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
  1059. if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
  1060. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  1061. << entire_costgraph->GetOperators().size() << " operators.";
  1062. } else {
  1063. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  1064. }
  1065. } else {
  1066. if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
  1067. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  1068. << entire_costgraph->GetOperators().size() << " operators.";
  1069. } else {
  1070. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  1071. }
  1072. }
  1073. ReshapeCostCompute(all_nodes);
  1074. auto ops = entire_costgraph->GetOperators();
  1075. std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
  1076. auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
  1077. for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
  1078. input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
  1079. }
  1080. std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
  1081. std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
  1082. std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
  1083. graph = EliminateGraph(graph, eli_list, index_list);
  1084. size_t num_device = g_device_manager->DeviceNum();
  1085. double device_memory = entire_costgraph->GetDeviceMemory();
  1086. if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
  1087. MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
  1088. } else {
  1089. MS_LOG(ERROR) << "PartitionForAllDevices failed.";
  1090. return FAILED;
  1091. }
  1092. GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list);
  1093. if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
  1094. MS_LOG(INFO) << "Init selected strategy succeeded.";
  1095. } else {
  1096. MS_LOG(ERROR) << "Init selected strategy failed.";
  1097. return FAILED;
  1098. }
  1099. // print the selected strategy
  1100. for (auto &op : entire_costgraph->GetOperators()) {
  1101. StrategyPtr s_strategy = op->selected_strategy();
  1102. MS_LOG(INFO) << op->name() << " : The strategy is:";
  1103. PrintStrategy(s_strategy);
  1104. }
  1105. return SUCCESS;
  1106. }
  1107. } // namespace parallel
  1108. } // namespace mindspore