You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

step_auto_parallel.cc 52 kB

5 years ago
5 years ago
5 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. #include "parallel/step_auto_parallel.h"
  17. #include <inttypes.h>
  18. #include <sys/time.h>
  19. #include <algorithm>
  20. #include <map>
  21. #include <memory>
  22. #include <set>
  23. #include <string>
  24. #include <unordered_map>
  25. #include <utility>
  26. #include <vector>
  27. #include "ir/anf.h"
  28. #include "ir/param_value.h"
  29. #include "ir/tensor.h"
  30. #include "optimizer/opt.h"
  31. #include "optimizer/optimizer.h"
  32. #include "parallel/auto_parallel/dp_algo_costmodel.h"
  33. #include "parallel/auto_parallel/edge_costmodel.h"
  34. #include "parallel/auto_parallel/graph_costmodel.h"
  35. #include "parallel/auto_parallel/rec_core/rec_generate_strategy.h"
  36. #include "parallel/auto_parallel/rec_core/rec_parse_graph.h"
  37. #include "parallel/auto_parallel/rec_core/rec_partition.h"
  38. #include "parallel/context.h"
  39. #include "parallel/ops_info/tmp_identity_info.h"
  40. #include "parallel/ops_info/reshape_info.h"
  41. #include "parallel/step_parallel.h"
  42. #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
  43. #include "pipeline/parse/python_adapter.h"
  44. #include "pipeline/pipeline.h"
  45. namespace mindspore {
  46. namespace parallel {
  47. bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
  48. MS_EXCEPTION_IF_NULL(root);
  49. MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  50. std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  51. // assume no change to graph
  52. bool changes = false;
  53. // control whether use model_parallel mode
  54. if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
  55. root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
  56. return changes;
  57. }
  58. // check whether strategy_search_mode is valid
  59. std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
  60. if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
  61. // Setting searching mode: dynanic programming as default.
  62. strategy_search_mode = DYNAMIC_PROGRAMMING;
  63. MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default";
  64. }
  65. struct timeval start_time, end_time;
  66. (void)gettimeofday(&start_time, nullptr);
  67. if (MsContext::GetInstance()->save_graphs_flag()) {
  68. draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
  69. }
  70. MS_LOG(INFO) << "Now entering step auto parallel";
  71. TOTAL_OPS = 0;
  72. AnfNodePtr ret = root->get_return();
  73. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  74. if (ParallelInit() != SUCCESS) {
  75. MS_LOG(EXCEPTION) << "Parallel init failed";
  76. }
  77. // mark the forward cnodes, parallel only care these nodes
  78. MarkForwardCNode(root);
  79. if (FindCommunicationOp(all_nodes)) {
  80. MS_LOG(EXCEPTION) << "The graph contain communication op";
  81. }
  82. // search parallelization strategy
  83. if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
  84. if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
  85. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
  86. }
  87. } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
  88. if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
  89. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
  90. }
  91. } else {
  92. MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
  93. }
  94. (void)gettimeofday(&end_time, nullptr);
  95. uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  96. time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  97. MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
  98. root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
  99. return changes;
  100. }
  101. // Given the node, return whether each input is a parameter or a output of a operator.
  102. // The returned boolean vector should be the same order of the inputs, thus its implementation
  103. // is closely consistent with ExtractShape() in step_parallel.cc
  104. std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
  105. std::vector<bool> is_parameter;
  106. std::vector<AnfNodePtr> node_inputs{node->inputs()};
  107. for (size_t i = 1; i < node_inputs.size(); ++i) {
  108. auto input = node_inputs[i];
  109. if (input->isa<Parameter>()) {
  110. auto input_parameter = input->cast<ParameterPtr>();
  111. if (input_parameter->has_default()) {
  112. bool requires_grad = input_parameter->default_param()->requires_grad();
  113. is_parameter.push_back(requires_grad);
  114. } else {
  115. is_parameter.push_back(false);
  116. }
  117. } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
  118. is_parameter.push_back(false);
  119. }
  120. }
  121. return is_parameter;
  122. }
  123. // Given the type, return the number of bytes to represent this type
  124. size_t GetLengthOfDataType(const TypePtr &type) {
  125. switch (type->type_id()) {
  126. case kNumberTypeBool:
  127. return sizeof(bool);
  128. case kNumberTypeInt8:
  129. return sizeof(int8_t);
  130. case kNumberTypeInt16:
  131. return sizeof(int16_t);
  132. case kNumberTypeInt32:
  133. return sizeof(int32_t);
  134. case kNumberTypeInt64:
  135. return sizeof(int64_t);
  136. case kNumberTypeUInt8:
  137. return sizeof(uint8_t);
  138. case kNumberTypeUInt16:
  139. return sizeof(uint16_t);
  140. case kNumberTypeUInt32:
  141. return sizeof(uint32_t);
  142. case kNumberTypeUInt64:
  143. return sizeof(uint64_t);
  144. case kNumberTypeFloat16:
  145. return sizeof(float) / 2;
  146. case kNumberTypeFloat32:
  147. return sizeof(float);
  148. case kNumberTypeFloat64:
  149. return sizeof(double);
  150. case kNumberTypeInt:
  151. return sizeof(int);
  152. case kNumberTypeUInt:
  153. return sizeof(unsigned int);
  154. case kNumberTypeFloat:
  155. return sizeof(float);
  156. default:
  157. MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  158. }
  159. }
  160. size_t GetInputsTypeLen(const AnfNodePtr &input) {
  161. MS_EXCEPTION_IF_NULL(input);
  162. if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
  163. MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
  164. }
  165. size_t input_type_len = 0;
  166. auto type = input->Type();
  167. MS_EXCEPTION_IF_NULL(type);
  168. if (type->isa<mindspore::TensorType>()) {
  169. auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
  170. input_type_len = GetLengthOfDataType(input_element_type);
  171. } else {
  172. MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  173. }
  174. return input_type_len;
  175. }
  176. std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
  177. MS_EXCEPTION_IF_NULL(node);
  178. std::vector<size_t> inputs_type_len;
  179. std::vector<AnfNodePtr> node_inputs{node->inputs()};
  180. // extract input element length
  181. for (auto &input : node_inputs) {
  182. if (IsValueNode<RefKey>(input)) {
  183. auto func_graph = node->func_graph();
  184. MS_EXCEPTION_IF_NULL(func_graph);
  185. std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
  186. if (parameters.size() != 1) {
  187. MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  188. }
  189. inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
  190. } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
  191. // extract input shape from parameter and apply node
  192. inputs_type_len.push_back(GetInputsTypeLen(input));
  193. }
  194. }
  195. return inputs_type_len;
  196. }
  197. std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
  198. MS_EXCEPTION_IF_NULL(node);
  199. std::vector<TypePtr> outputs_type;
  200. // extract output element type
  201. auto primary_output_type = node->Type();
  202. MS_EXCEPTION_IF_NULL(primary_output_type);
  203. if (primary_output_type->isa<mindspore::Tuple>()) {
  204. // in this case, the output is a tuple
  205. auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
  206. auto elements = tuple_output_type->elements();
  207. for (auto &ele : elements) {
  208. if (ele->isa<mindspore::TensorType>()) {
  209. auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
  210. outputs_type.push_back(ele_element_type);
  211. } else {
  212. MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
  213. }
  214. }
  215. } else {
  216. // in this case, the output is a single tensor
  217. if (primary_output_type->isa<mindspore::TensorType>()) {
  218. auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
  219. outputs_type.push_back(element_type);
  220. } else {
  221. MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
  222. }
  223. }
  224. return outputs_type;
  225. }
  226. bool IsElementWiseOperator(const std::string &op_name) {
  227. static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU,
  228. SQRT, CAST, POW, EXP, LOG, COS,
  229. ACOS, LOGICALNOT, NEG, SQUARE, SIGMOID};
  230. auto iter = elementwise_op.find(op_name);
  231. return (iter != elementwise_op.end());
  232. }
  233. bool IsSplittableOperator(const std::string &op_name) {
  234. // clang-format off
  235. static const std::set<std::string> splittable_op =
  236. {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
  237. FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
  238. REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
  239. MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
  240. LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
  241. STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2,
  242. SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
  243. // clang-format on
  244. auto iter = splittable_op.find(op_name);
  245. return (iter != splittable_op.end());
  246. }
  247. bool IsAutoParallelCareNode(const CNodePtr &cnode) {
  248. MS_EXCEPTION_IF_NULL(cnode);
  249. ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  250. if (prim_node == nullptr) {
  251. return false;
  252. }
  253. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
  254. if (prim == nullptr) {
  255. return false;
  256. }
  257. bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
  258. if (bool_result) {
  259. MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
  260. } else if (prim->name() == CAST) {
  261. if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
  262. // Do not care CASTs from optimizer
  263. return false;
  264. }
  265. return true;
  266. }
  267. return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
  268. }
  269. OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
  270. MS_EXCEPTION_IF_NULL(prim);
  271. MS_EXCEPTION_IF_NULL(cnode);
  272. auto attrs = prim->attrs();
  273. std::vector<Shapes> shape_list = ExtractShape(cnode);
  274. if (shape_list.empty()) {
  275. MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
  276. }
  277. // Create an OperatorInfo instance
  278. OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
  279. MS_EXCEPTION_IF_NULL(operator_info);
  280. // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
  281. std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
  282. if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
  283. MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
  284. return nullptr;
  285. }
  286. // Set the data type for inputs and outputs of this OperatorInfo
  287. auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
  288. auto outputs_type = ExtractOutputTypeByNode(cnode);
  289. std::vector<size_t> outputs_type_length;
  290. outputs_type_length.reserve(outputs_type.size());
  291. std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
  292. GetLengthOfDataType);
  293. if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
  294. MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
  295. return nullptr;
  296. }
  297. if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
  298. MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
  299. return nullptr;
  300. }
  301. // When the 'inputs' contains numerical values for some operators, these values should be extracted from
  302. // ANF graph
  303. auto &inputs = cnode->inputs();
  304. std::vector<ValuePtr> input_value;
  305. for (size_t index = 1; index < inputs.size(); ++index) {
  306. if (inputs[index]->isa<ValueNode>()) {
  307. input_value.push_back(GetValueNode(inputs[index]));
  308. } else {
  309. input_value.emplace_back(nullptr);
  310. }
  311. }
  312. operator_info->set_input_value(input_value);
  313. operator_info->set_outputs_dtype(cnode->Type());
  314. operator_info->set_cnode(cnode);
  315. // key of strategy map
  316. std::string strategy_key_name = NodeParameterName(cnode);
  317. bool load_strategy_from_ckpt =
  318. StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
  319. // If no strategy has been configured for this operator, then candidate strategies are generated for
  320. // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
  321. // if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint .
  322. if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
  323. // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
  324. // BatchParallelInfo operator
  325. operator_info->ComputeBatchSplitFlagList();
  326. if (operator_info->GenerateStrategies(0) != SUCCESS) {
  327. MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
  328. return nullptr;
  329. }
  330. } else {
  331. // In this case, the configured strategy should be extracted to help setting cost
  332. StrategyPtr strategyPtr;
  333. if (load_strategy_from_ckpt) {
  334. strategyPtr = (*stra_map)[strategy_key_name];
  335. } else {
  336. strategyPtr = parallel::ExtractStrategy(attrs);
  337. }
  338. if (strategyPtr != nullptr) {
  339. if (prim->name() == RESHAPE) {
  340. MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
  341. }
  342. // Set cost for this configured strategy
  343. if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
  344. MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
  345. } else if (FULLY_USE_DEVICES) {
  346. // If configured to fully use devices, then checking for the user-specified strategy
  347. int32_t used_devices = operator_info->used_devices();
  348. MS_EXCEPTION_IF_NULL(g_device_manager);
  349. auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
  350. // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel
  351. if (used_devices == 1) {
  352. return operator_info;
  353. }
  354. // 'used_devices == -1' means that 'used_devices_' is not set
  355. if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) {
  356. MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, "
  357. << "but the specified strategy uses device: " << used_devices
  358. << ", total devices: " << total_device_num;
  359. }
  360. }
  361. }
  362. }
  363. return operator_info;
  364. }
  365. // Using CNode's UniqueIds to construct nodes
  366. Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  367. MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  368. entire_costgraph = std::make_shared<CostGraph>();
  369. entire_costgraph->SetDeviceMemoryAndCostParameter();
  370. // The map from CNode's UniqueId to its operatorInfo
  371. std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  372. // extract strategy from checkpoint for multi-train
  373. StrategyMap stra_map;
  374. if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
  375. if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
  376. MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  377. }
  378. }
  379. // Step 1
  380. for (auto &node : all_nodes) {
  381. // NOTE: we only care about splittable Primitive operators
  382. auto cnode = node->cast<CNodePtr>();
  383. bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
  384. if (bool_result) {
  385. continue;
  386. }
  387. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  388. if (!IsAutoParallelCareNode(cnode)) {
  389. // Needed by rec_parser
  390. if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
  391. auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
  392. if (prev_cnode != nullptr) {
  393. entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
  394. }
  395. }
  396. continue;
  397. }
  398. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  399. MS_EXCEPTION_IF_NULL(prim);
  400. auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
  401. if (search_cnode == from_cnode_to_info.end()) {
  402. auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
  403. if (operator_info == nullptr) {
  404. return FAILED;
  405. }
  406. // Needed by rec_parser
  407. operator_info->set_type(prim->name());
  408. std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
  409. entire_costgraph->AddOperator(operator_info);
  410. (void)cnode->set_operator_info(operator_info);
  411. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  412. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  413. << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
  414. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
  415. // Needed by rec_parser
  416. entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
  417. } else {
  418. // Two CNODEs' UniqueIds should not be equal
  419. MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
  420. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  421. << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
  422. }
  423. }
  424. MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  425. return SUCCESS;
  426. }
  427. // Using CNode's UniqueIdThroughCopys to construct nodes
  428. Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  429. MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  430. entire_costgraph = std::make_shared<CostGraph>();
  431. entire_costgraph->SetDeviceMemoryAndCostParameter();
  432. // The map from CNode's UniqueIdThroughCopy to its operatorInfo
  433. std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  434. // extract strategy from checkpoint for multi-train
  435. StrategyMap stra_map;
  436. if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
  437. if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
  438. MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  439. }
  440. }
  441. for (auto &node : all_nodes) {
  442. // NOTE: we only care about splittable Primitive operators
  443. auto cnode = node->cast<CNodePtr>();
  444. bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
  445. if (bool_result) {
  446. continue;
  447. }
  448. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  449. if (!IsAutoParallelCareNode(cnode)) {
  450. // Needed by rec_parser
  451. if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
  452. auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
  453. if (prev_cnode != nullptr) {
  454. entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
  455. }
  456. }
  457. continue;
  458. }
  459. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  460. // Find the operatorInfo if it exists
  461. auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
  462. if (search_cnode == from_cnode_to_info.end()) {
  463. // In this case, the corresponding OperatorInfo is not created, create the new one.
  464. auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
  465. if (operator_info == nullptr) {
  466. return FAILED;
  467. }
  468. // Needed by rec_parser
  469. operator_info->set_type(prim->name());
  470. std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
  471. entire_costgraph->AddOperator(operator_info);
  472. (void)cnode->set_operator_info(operator_info);
  473. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  474. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  475. << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
  476. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
  477. // Needed by rec_parser
  478. entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
  479. } else {
  480. auto current_op_ptr = search_cnode->second;
  481. if (current_op_ptr == nullptr) {
  482. MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
  483. } else {
  484. bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
  485. (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
  486. (current_op_ptr->name().find(prim->name()) == std::string::npos);
  487. if (is_find_wrong) {
  488. MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
  489. << " does not match the Prim: " << prim->name();
  490. }
  491. (void)cnode->set_operator_info(current_op_ptr);
  492. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  493. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  494. << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
  495. }
  496. }
  497. }
  498. MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  499. return SUCCESS;
  500. }
  501. void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
  502. // Step 2
  503. MS_LOG(INFO) << "Constructing edges for cost graph begins.";
  504. for (auto &node : all_nodes) {
  505. auto cnode = node->cast<CNodePtr>();
  506. bool bool_result_cnode = (cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0));
  507. if (bool_result_cnode) {
  508. continue;
  509. }
  510. auto &inputs = cnode->inputs();
  511. ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
  512. if (!IsAutoParallelCareNode(cnode)) {
  513. continue;
  514. }
  515. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  516. size_t edge_count = 0;
  517. for (size_t i = 1; i < inputs.size(); ++i) {
  518. auto prev_cnode = inputs[i]->cast<CNodePtr>();
  519. bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  520. if (bool_result_prev_cnode) {
  521. continue;
  522. }
  523. ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  524. PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  525. size_t output_index = 0;
  526. bool bool_result =
  527. (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
  528. while (bool_result) {
  529. if (IsAutoParallelCareNode(prev_cnode)) {
  530. std::string edge_name =
  531. prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name();
  532. // If the edge between these two operators already has been added, then the edge will not be added again.
  533. if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
  534. break;
  535. }
  536. EdgePtr edge_ptr;
  537. MS_LOG(INFO) << "Creating edge: " << edge_name;
  538. bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
  539. (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
  540. if (follow_strategy) {
  541. // Redistribution in not allowed on the edge.
  542. // Elementwise operators have the same strategy as their previous operators.
  543. edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
  544. output_index, i - 1, false, true);
  545. } else {
  546. edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
  547. output_index, i - 1, false);
  548. }
  549. // Init costs for this edge
  550. if (edge_ptr->InitEdgeCost() != SUCCESS) {
  551. MS_LOG(EXCEPTION) << "Edge cost initialization failed";
  552. }
  553. cnode->operator_info()->AddPrevEdge(edge_ptr);
  554. prev_cnode->operator_info()->AddSuccEdge(edge_ptr);
  555. entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr);
  556. MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and "
  557. << cnode->operator_info()->name();
  558. edge_count++;
  559. break;
  560. } else if (prev_prim->name() == TUPLE_GETITEM) {
  561. // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
  562. // this 'tuple_getitem'
  563. MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
  564. output_index = IntToSize(GetValue<int>(GetValueNode(prev_cnode->input(2))));
  565. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  566. bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  567. if (bool_result_tuple) {
  568. break;
  569. }
  570. prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  571. prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  572. if (!IsAutoParallelCareNode(prev_cnode)) {
  573. MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
  574. }
  575. MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
  576. << "and creating an edge between the Operator before "
  577. << "'tuple_getitem' and the Operator after 'tuple_getitem'.";
  578. } else if (prev_prim->name() == DEPEND) {
  579. // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before
  580. // this 'depend'
  581. MS_LOG(INFO) << "Jumping the 'depend' operator.";
  582. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  583. bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  584. if (bool_result_depend) {
  585. break;
  586. }
  587. prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  588. prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  589. MS_LOG(INFO) << "Jumped the 'depend' operator, "
  590. << "and creating an edge between the Operator before "
  591. << "'depend' and the Operator after 'depend'.";
  592. }
  593. bool_result =
  594. (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
  595. }
  596. }
  597. MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name();
  598. }
  599. MS_LOG(INFO) << "Constructing edges for cost graph ends.";
  600. }
  601. std::pair<AnfNodePtr, std::vector<AnfNodePtr>> CNodeWithRefKeys(const AnfNodePtr &cnode) {
  602. MS_EXCEPTION_IF_NULL(cnode);
  603. std::vector<AnfNodePtr> refkeys;
  604. if (cnode->isa<CNode>()) {
  605. auto cnode_ptr = cnode->cast<CNodePtr>();
  606. auto inputs = cnode_ptr->inputs();
  607. for (auto &one_input : inputs) {
  608. if (IsValueNode<RefKey>(one_input)) {
  609. refkeys.push_back(one_input);
  610. }
  611. }
  612. if (refkeys.size() >= 1) {
  613. return std::make_pair(cnode, refkeys);
  614. }
  615. }
  616. return {nullptr, refkeys};
  617. }
// Step 3 of auto-parallel strategy search: augment the cost graph for shared Parameters.
// When a single Parameter (found either directly or through a RefKey input) is consumed by
// more than one distinct operator, a TmpIdentityInfo operator is created for it and an edge
// is added from that TmpIdentity to each consuming operator, so the search algorithm can
// account for the cost of distributing the shared parameter.
void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  // Step 3
  for (auto &node : all_nodes) {
    auto cnode_with_refkeys = CNodeWithRefKeys(node);
    // Only Parameters and CNodes that carry a RefKey input are relevant here.
    if ((!node->isa<Parameter>()) && (cnode_with_refkeys.first == nullptr)) {
      continue;
    }
    std::string parameter_name;
    AnfNodePtr target_parameter = nullptr;
    // 'target_set' accumulates every (consumer CNode, input index) pair that uses
    // the parameter/refkey under consideration.
    AnfNodeIndexSet target_set;
    if (cnode_with_refkeys.first != nullptr) {
      // Dealing with the RefKey case
      auto refkeys = cnode_with_refkeys.second;
      auto cnode = cnode_with_refkeys.first;
      auto cnode_ptr = cnode->cast<CNodePtr>();
      if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
        continue;
      }
      if (!IsAutoParallelCareNode(cnode_ptr)) {
        continue;
      }
      if (refkeys.size() > 1) {
        MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
      }
      MS_EXCEPTION_IF_NULL(cnode->func_graph());
      auto cnode_func_graph = cnode->func_graph();
      MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
      // Find the RefKey being used: collect every auto-parallel-care user of the RefKey node.
      auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
      for (auto &candidate : candidate_set_by_refkey) {
        auto candidate_node = candidate.first;
        auto c = candidate_node->cast<CNodePtr>();
        if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
          continue;
        }
        if (!IsAutoParallelCareNode(c)) {
          continue;
        }
        target_set.add(candidate);
      }
      // Find the corresponding Parameter being used
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      parameter_name = parameters[0]->cast<ParameterPtr>()->name();
      target_parameter = parameters[0];
      // Also collect the users of the Parameter itself (it may be used directly
      // in addition to being referenced through the RefKey).
      auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
      for (auto &candidate : candidate_set_by_para) {
        auto candidate_node = candidate.first;
        auto c = candidate_node->cast<CNodePtr>();
        if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
          continue;
        }
        if (!IsAutoParallelCareNode(c)) {
          continue;
        }
        (void)target_set.insert(candidate);
      }
    } else if (node->isa<Parameter>()) {
      // Dealing with the Parameter case
      MS_EXCEPTION_IF_NULL(node->func_graph());
      MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
      auto candidate_set = node->func_graph()->manager()->node_users()[node];
      for (auto &candidate : candidate_set) {
        auto candidate_node = candidate.first;
        auto c = candidate_node->cast<CNodePtr>();
        if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
          continue;
        }
        if (!IsAutoParallelCareNode(c)) {
          continue;
        }
        (void)target_set.insert(candidate);
      }
      // In this case, node is a Parameter
      parameter_name = node->cast<ParameterPtr>()->name();
      target_parameter = node;
    }
    // A parameter used at most once needs no TmpIdentity.
    if (target_set.size() <= 1) {
      continue;
    }
    // Rule out the case when a Parameter being used by a Operator, but the Operator appears in multiple CNODEs
    std::set<std::string> target_without_duplicate;
    for (auto &target : target_set) {
      auto target_cnode = target.first->cast<CNodePtr>();
      auto input_index = target.second;
      // De-duplicate by (input index, operator name): distinct CNodes may share one OperatorInfo.
      (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name());
    }
    if (target_without_duplicate.size() <= 1) {
      continue;
    }
    // Here, it is sure that this Parameter (RefKey) is being used by multiple Operators.
    OperatorInfoPtr tmp_identity_ptr;
    bool new_identity = false;
    std::string tmp_identity_name;
    auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
    if (returned_identity != nullptr) {
      // In this case, the TmpIdentityInfo instance has already been created
      new_identity = false;
      tmp_identity_ptr = returned_identity;
      tmp_identity_name = tmp_identity_ptr->name();
    } else {
      // In the case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created.
      new_identity = true;
      // 1) extract input shape from this Parameter
      MS_EXCEPTION_IF_NULL(target_parameter);
      AbstractBasePtr abstract = target_parameter->abstract();
      if (abstract == nullptr) {
        MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
      }
      auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
      if (input_shape == nullptr) {
        MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
      }
      std::vector<int> shape_int = input_shape->shape();
      Shape shape;
      (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape),
                           [](int sub_shape) { return static_cast<int32_t>(sub_shape); });
      // Identity: output shape equals input shape.
      Shapes inputs_shape = {shape};
      Shapes outputs_shape = {shape};
      // 2) init the attr
      std::unordered_map<std::string, ValuePtr> attr = {};
      // Create the TmpIdentity instance
      tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
      // Suffix with the global op counter so the name is unique in the cost graph.
      tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
      TOTAL_OPS++;
      tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
      // Set the parameter and type lengths for inputs and outputs
      std::vector<bool> is_parameter;
      auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(casted_target_parameter);
      if (casted_target_parameter->has_default()) {
        bool requires_grad = casted_target_parameter->default_param()->requires_grad();
        is_parameter.push_back(requires_grad);
      } else {
        is_parameter.push_back(false);
      }
      if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
      }
      auto node_type = target_parameter->Type();
      if (node_type->isa<mindspore::TensorType>()) {
        auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
        std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
        if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
          MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
        }
      } else {
        MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
      }
      // Generate strategies for this TmpIdentityInfo instance;
      if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name();
      }
    }
    // A flag recording whether new edges have been created or not
    bool add_identity_edge = false;
    // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
    for (auto &target : target_set) {
      auto target_cnode = target.first->cast<CNodePtr>();
      auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
      auto input_index = target.second;
      std::string edge_name =
        std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name();
      // If the edge between these two operators already has been added, then the edge will not be added again.
      if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) {
        continue;
      }
      // Output index of TmpIdentity is 0; consumer input index is 1-based in the CNode,
      // hence 'input_index - 1' on the edge.
      std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
        edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
      if (edge_ptr->InitEdgeCost() != SUCCESS) {
        MS_LOG(EXCEPTION) << "Edge cost initialization failed";
      }
      target_cnode->operator_info()->AddPrevEdge(edge_ptr);
      tmp_identity_ptr->AddSuccEdge(edge_ptr);
      entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr);
      MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
                   << target_cnode->operator_info()->name();
      add_identity_edge = true;
    }
    if (new_identity && add_identity_edge) {
      // Add the TmpIdentityInfo to CostGraph if BOTH two conditions are satisfied
      entire_costgraph->AddOperator(tmp_identity_ptr);
    }
  }
}
  805. bool FindReshape(const CNodePtr &cnode) {
  806. if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
  807. return false;
  808. }
  809. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  810. if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
  811. return false;
  812. }
  813. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  814. MS_EXCEPTION_IF_NULL(prim);
  815. OperatorInfoPtr operator_info = cnode->operator_info();
  816. if (operator_info == nullptr) {
  817. MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
  818. }
  819. if (prim->name() != RESHAPE) {
  820. return false;
  821. }
  822. return true;
  823. }
// find previous node, then obtain its strategy_cost_ vector to get its layout vector.
// On success, writes the previous operator's OperatorInfo to '*pre_operator_info' and the
// output index of that operator to '*out_index'; returns false when no suitable previous
// operator can be found (e.g. the previous node is a Parameter, handled by the caller).
bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
  // if previous node is a parameter, handle it in the outsize.
  if (node->isa<Parameter>()) {
    return false;
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  // The node itself is an operator with strategy costs: use it directly (output 0).
  if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
    *pre_operator_info = cnode->operator_info();
    *out_index = 0;
    return true;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == TUPLE_GETITEM) {
    // tuple_getitem selects one output of its producer; record which one.
    *out_index = GetTupleGetItemIndex(cnode);
    // find tuple_get_item's previous node
    auto pre_node = cnode->input(1);
    if (!pre_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
    }
    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
    if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
      *pre_operator_info = pre_cnode->operator_info();
      return true;
    }
    return false;
  }
  // Otherwise recurse through the inputs; for Depend only input 1 carries data,
  // the remaining inputs are control dependencies and are skipped.
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
      continue;
    }
    return true;
  }
  MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
  return false;
}
// find next node, then obtain its strategy_cost_ vector to get its layout vector.
// if reshape's output connect to several primitive, return the first layout found
// On success writes the consumer's OperatorInfo to '*next_operator_info' and the
// 0-based input index at which 'cnode' feeds that consumer to '*in_index'.
bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // All (user CNode, input index) pairs consuming 'cnode'.
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
    // For Depend, only position 1 is a data input; other positions are control edges.
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
      *next_operator_info = use_apply->operator_info();
      // node_pair.second is the 1-based CNode input position; convert to 0-based.
      *in_index = node_pair.second - 1;
      return true;
    }
    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
                  << " " << (use_apply->operator_info() != nullptr);
    // The direct user is not a parallel-care operator: search through it recursively.
    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
      return true;
    }
  }
  return false;
}
  905. void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
  906. for (auto node : all_nodes) {
  907. auto cnode = node->cast<CNodePtr>();
  908. if (!FindReshape(cnode)) {
  909. continue;
  910. }
  911. MS_ASSERT(cnode->inputs().size() == 3);
  912. // get previous node's strategy_cost_
  913. auto pre_node = cnode->input(1);
  914. int32_t out_index = 0;
  915. OperatorInfoPtr pre_operator_info;
  916. std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
  917. if (pre_node->isa<Parameter>()) {
  918. OperatorInfoPtr operator_info = cnode->operator_info();
  919. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
  920. reshape_info->SetCostForReshapeWithParameter();
  921. pre_operator_info = reshape_info;
  922. pre_stra_costs = reshape_info->strategy_cost();
  923. } else {
  924. if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
  925. MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
  926. }
  927. pre_stra_costs = pre_operator_info->strategy_cost();
  928. }
  929. // get next node's strategy_cost_
  930. int32_t in_index = 0;
  931. OperatorInfoPtr next_operator_info;
  932. std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
  933. bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
  934. if (!find_next_node) {
  935. MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
  936. }
  937. // set input_layout and output_layout for reshape.
  938. // init reshape and set cost for each input_layout and output_layout.
  939. OperatorInfoPtr operator_info = cnode->operator_info();
  940. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
  941. reshape_info->set_pre_operator_name(pre_operator_info->name());
  942. reshape_info->set_pre_operator_index(out_index);
  943. if (find_next_node) {
  944. next_stra_costs = next_operator_info->strategy_cost();
  945. reshape_info->set_next_operator_name(next_operator_info->name());
  946. reshape_info->set_next_operator_index(in_index);
  947. }
  948. bool is_prev_param = pre_node->isa<Parameter>();
  949. if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
  950. SUCCESS) {
  951. MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!";
  952. }
  953. }
  954. }
// Entry point of the dynamic-programming auto-parallel strategy search.
// Returns SUCCESS when a strategy has been selected and initialized for every operator,
// FAILED when the DP search itself fails; raises on construction/initialization errors.
Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
  // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
  //      create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
  //      for each OperatorInfo;
  // Step 1.1: Deal with 'Reshape':
  //      For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
  //      layout as its output layout.
  // Step 2: Traverse the ANF graph, and create EDGES for costgraph:
  //      create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
  //      for each edge, based on the strategies of two OperatorInfos;
  // Step 3: Augment the costgraph:
  //      taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
  //      operator for this Parameter, and add an edge for the use of this Parameter by each
  //      subsequent operator;
  // Step 3.1: Calculate memory usage:
  //      note the memory usage calculation is different in training phase and inference phase.
  // Step 4: Run the Dynamic Programming algorithm:
  //      in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
  //      cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
  //      tensor layout. Note that there may be several connected components in the costgraph, and the DP algorithm
  //      runs on each of them.
  //
  // OUTPUT: the determined strategy for each operator.
  // Step 1: node construction differs for multi-subgraph models (TC variant).
  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  } else {
    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  }
  // Step 1.1
  ReshapeCostCompute(all_nodes);
  // Step 2
  ConstructCostGraphEdges(all_nodes);
  MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
               << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
  // Step 3: Augment the costgraph.
  AugmentCostGraph(all_nodes);
  MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size()
               << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
  // Step 3.1: Calculate the memory usage
  if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
    MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
  }
  // Step 4: run DP algorithm on the costgraph.
  if (GetStrategy(entire_costgraph) != SUCCESS) {
    MS_LOG(ERROR) << "Strategy search for cost-graph fails";
    return FAILED;
  }
  MS_LOG(INFO) << "Searching strategy succeeded.";
  if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
    MS_LOG(INFO) << "Init selected strategy succeeded.";
  } else {
    MS_LOG(EXCEPTION) << "Init selected strategy failed.";
  }
  // print the selected strategy
  for (auto &op : entire_costgraph->GetOperators()) {
    StrategyPtr s_strategy = op->selected_strategy();
    MS_LOG(INFO) << op->name() << " : The strategy is:";
    PrintStrategy(s_strategy);
  }
  return SUCCESS;
}
  1028. std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
  1029. std::vector<std::vector<std::string>> input_tensor_names) {
  1030. for (size_t j = 0; j < input_tensor_names.size(); j++) {
  1031. for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
  1032. if (it->first == input_tensor_names[j][k]) {
  1033. input_tensor_names[j][k] = it->second;
  1034. break;
  1035. }
  1036. }
  1037. }
  1038. return input_tensor_names;
  1039. }
  1040. CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
  1041. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  1042. if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) {
  1043. auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
  1044. if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
  1045. return nullptr;
  1046. }
  1047. auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  1048. while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) {
  1049. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  1050. if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
  1051. return nullptr;
  1052. }
  1053. prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  1054. }
  1055. return prev_cnode;
  1056. }
  1057. return nullptr;
  1058. }
// Entry point of the recursive-programming auto-parallel strategy search.
// Builds the cost graph nodes, converts operators and tensor-name lists into the
// recursive-partition Graph representation, partitions it across all devices, then
// generates and initializes the selected strategies.
// Returns SUCCESS on full success, FAILED when partitioning or strategy
// initialization fails; raises on node-construction errors.
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  // Node construction differs for multi-subgraph models (TC variant).
  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  } else {
    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                   << entire_costgraph->GetOperators().size() << " operators.";
    } else {
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  }
  ReshapeCostCompute(all_nodes);
  auto ops = entire_costgraph->GetOperators();
  std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
  // Rewrite tuple_getitem aliases in the input tensor-name lists so edges refer
  // to the actual producing operators.
  auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
  for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
    input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
  }
  std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
  // eli_list/index_list record the operators eliminated/kept by graph simplification;
  // they are needed later to map partition results back to the original operators.
  std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
  std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
  graph = EliminateGraph(graph, eli_list, index_list);
  size_t num_device = g_device_manager->DeviceNum();
  double device_memory = entire_costgraph->GetDeviceMemory();
  if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
    MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
  } else {
    MS_LOG(ERROR) << "PartitionForAllDevices failed.";
    return FAILED;
  }
  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list);
  if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
    MS_LOG(INFO) << "Init selected strategy succeeded.";
  } else {
    MS_LOG(ERROR) << "Init selected strategy failed.";
    return FAILED;
  }
  // print the selected strategy
  for (auto &op : entire_costgraph->GetOperators()) {
    StrategyPtr s_strategy = op->selected_strategy();
    MS_LOG(INFO) << op->name() << " : The strategy is:";
    PrintStrategy(s_strategy);
  }
  return SUCCESS;
}
  1109. } // namespace parallel
  1110. } // namespace mindspore