You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

step_auto_parallel.cc 53 kB

6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/step_auto_parallel.h"
  17. #include <inttypes.h>
  18. #include <sys/time.h>
  19. #include <algorithm>
  20. #include <map>
  21. #include <memory>
  22. #include <set>
  23. #include <string>
  24. #include <unordered_map>
  25. #include <utility>
  26. #include <vector>
  27. #include "ir/anf.h"
  28. #include "ir/param_value_py.h"
  29. #include "ir/meta_tensor.h"
  30. #include "optimizer/opt.h"
  31. #include "optimizer/optimizer.h"
  32. #include "parallel/auto_parallel/dp_algo_costmodel.h"
  33. #include "parallel/auto_parallel/edge_costmodel.h"
  34. #include "parallel/auto_parallel/graph_costmodel.h"
  35. #include "parallel/auto_parallel/rec_core/rec_generate_strategy.h"
  36. #include "parallel/auto_parallel/rec_core/rec_parse_graph.h"
  37. #include "parallel/auto_parallel/rec_core/rec_partition.h"
  38. #include "parallel/context.h"
  39. #include "parallel/ops_info/tmp_identity_info.h"
  40. #include "parallel/ops_info/reshape_info.h"
  41. #include "parallel/step_parallel.h"
  42. #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
  43. #include "pipeline/parse/python_adapter.h"
  44. #include "pipeline/pipeline.h"
  45. namespace mindspore {
  46. namespace parallel {
  47. // splittable_op_ will continuously be updated
  48. std::vector<std::string> splittable_op_ = {MATMUL,
  49. GELU,
  50. TANH,
  51. SOFTMAX,
  52. LOG_SOFTMAX,
  53. ACTIVATION,
  54. PRELU,
  55. FLOORDIV,
  56. L2_NORMALIZE,
  57. TRANSPOSE,
  58. RESHAPE,
  59. TENSOR_ADD,
  60. SUB,
  61. MUL,
  62. DIV,
  63. GREATER,
  64. MAXPOOL,
  65. MAXPOOLV2,
  66. VIRTUAL_DATA_SET,
  67. SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
  68. RELU,
  69. ONEHOT,
  70. DROPOUT_DO_MASK,
  71. REDUCE_MAX,
  72. REDUCE_MIN,
  73. ARGMAXWITHVALUE,
  74. ARGMINWITHVALUE,
  75. REDUCE_SUM,
  76. CONV2D,
  77. FUSE_BATCH_NORM,
  78. POOLING,
  79. SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
  80. SIGMOID_CROSS_ENTROPY_WITH_LOGITS,
  81. MAX_POOL_WITH_ARGMAX,
  82. SIMPLE_MEAN,
  83. FLATTEN,
  84. BATCH_NORM,
  85. LAYER_NORM,
  86. BIAS_ADD,
  87. ASSIGN_SUB,
  88. COS,
  89. ACOS,
  90. EXP,
  91. LOG,
  92. REDUCE_MEAN,
  93. REAL_DIV,
  94. SIGMOID,
  95. POW,
  96. MAXIMUM,
  97. MINIMUM,
  98. EQUAL,
  99. NOT_EQUAL,
  100. LOGICALNOT,
  101. GATHERV2,
  102. STRIDEDSLICE,
  103. SQRT,
  104. GET_NEXT,
  105. CAST,
  106. NEG,
  107. SQUARE,
  108. BATCH_MATMUL,
  109. EXPAND_DIMS,
  110. SQUEEZE};
  111. std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT, CAST,
  112. POW, EXP, LOG, COS, ACOS, LOGICALNOT, NEG, SQUARE};
  113. bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
  114. MS_EXCEPTION_IF_NULL(root);
  115. MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  116. std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  117. // assume no change to graph
  118. bool changes = false;
  119. // control whether use model_parallel mode
  120. if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
  121. root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
  122. return changes;
  123. }
  124. // check whether strategy_search_mode is valid
  125. std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
  126. if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
  127. // Setting searching mode: dynanic programming as default.
  128. strategy_search_mode = DYNAMIC_PROGRAMMING;
  129. MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default";
  130. }
  131. struct timeval start_time, end_time;
  132. (void)gettimeofday(&start_time, nullptr);
  133. if (MsContext::GetInstance()->save_graphs_flag()) {
  134. draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
  135. }
  136. MS_LOG(INFO) << "Now entering step auto parallel";
  137. TOTAL_OPS = 0;
  138. AnfNodePtr ret = root->get_return();
  139. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  140. if (ParallelInit() != SUCCESS) {
  141. MS_LOG(EXCEPTION) << "Parallel init failed";
  142. }
  143. // mark the forward cnodes, parallel only care these nodes
  144. MarkForwardCNode(root);
  145. if (FindCommunicationOp(all_nodes)) {
  146. MS_LOG(EXCEPTION) << "The graph contain communication op";
  147. }
  148. // search parallelization strategy
  149. if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
  150. if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
  151. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
  152. }
  153. } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
  154. if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
  155. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
  156. }
  157. } else {
  158. MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
  159. }
  160. (void)gettimeofday(&end_time, nullptr);
  161. uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  162. time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  163. MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
  164. root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY] = true;
  165. return changes;
  166. }
  167. // Given the node, return whether each input is a parameter or a output of a operator.
  168. // The returned boolean vector should be the same order of the inputs, thus its implementation
  169. // is closely consistent with ExtractShape() in step_parallel.cc
  170. std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
  171. std::vector<bool> is_parameter;
  172. std::vector<AnfNodePtr> node_inputs{node->inputs()};
  173. for (size_t i = 1; i < node_inputs.size(); ++i) {
  174. auto input = node_inputs[i];
  175. if (input->isa<Parameter>()) {
  176. auto input_parameter = input->cast<ParameterPtr>();
  177. if (input_parameter->has_default()) {
  178. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
  179. bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
  180. is_parameter.push_back(require_grad);
  181. } else {
  182. is_parameter.push_back(false);
  183. }
  184. } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
  185. is_parameter.push_back(false);
  186. }
  187. }
  188. return is_parameter;
  189. }
  190. // Given the type, return the number of bytes to represent this type
  191. size_t GetLengthOfDataType(const TypePtr &type) {
  192. switch (type->type_id()) {
  193. case kNumberTypeBool:
  194. return sizeof(bool);
  195. case kNumberTypeInt8:
  196. return sizeof(int8_t);
  197. case kNumberTypeInt16:
  198. return sizeof(int16_t);
  199. case kNumberTypeInt32:
  200. return sizeof(int32_t);
  201. case kNumberTypeInt64:
  202. return sizeof(int64_t);
  203. case kNumberTypeUInt8:
  204. return sizeof(uint8_t);
  205. case kNumberTypeUInt16:
  206. return sizeof(uint16_t);
  207. case kNumberTypeUInt32:
  208. return sizeof(uint32_t);
  209. case kNumberTypeUInt64:
  210. return sizeof(uint64_t);
  211. case kNumberTypeFloat16:
  212. return sizeof(float) / 2;
  213. case kNumberTypeFloat32:
  214. return sizeof(float);
  215. case kNumberTypeFloat64:
  216. return sizeof(double);
  217. case kNumberTypeInt:
  218. return sizeof(int);
  219. case kNumberTypeUInt:
  220. return sizeof(unsigned int);
  221. case kNumberTypeFloat:
  222. return sizeof(float);
  223. default:
  224. MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  225. }
  226. }
  227. size_t GetInputsTypeLen(const AnfNodePtr &input) {
  228. MS_EXCEPTION_IF_NULL(input);
  229. if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
  230. MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
  231. }
  232. size_t input_type_len = 0;
  233. auto type = input->Type();
  234. MS_EXCEPTION_IF_NULL(type);
  235. if (type->isa<mindspore::TensorType>()) {
  236. auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
  237. input_type_len = GetLengthOfDataType(input_element_type);
  238. } else {
  239. MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  240. }
  241. return input_type_len;
  242. }
  243. std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
  244. MS_EXCEPTION_IF_NULL(node);
  245. std::vector<size_t> inputs_type_len;
  246. std::vector<AnfNodePtr> node_inputs{node->inputs()};
  247. // extract input element length
  248. for (auto &input : node_inputs) {
  249. if (IsValueNode<RefKey>(input)) {
  250. auto func_graph = node->func_graph();
  251. MS_EXCEPTION_IF_NULL(func_graph);
  252. std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
  253. if (parameters.size() != 1) {
  254. MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  255. }
  256. inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
  257. } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
  258. // extract input shape from parameter and apply node
  259. inputs_type_len.push_back(GetInputsTypeLen(input));
  260. }
  261. }
  262. return inputs_type_len;
  263. }
  264. std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
  265. MS_EXCEPTION_IF_NULL(node);
  266. std::vector<TypePtr> outputs_type;
  267. // extract output element type
  268. auto primary_output_type = node->Type();
  269. MS_EXCEPTION_IF_NULL(primary_output_type);
  270. if (primary_output_type->isa<mindspore::Tuple>()) {
  271. // in this case, the output is a tuple
  272. auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
  273. auto elements = tuple_output_type->elements();
  274. for (auto &ele : elements) {
  275. if (ele->isa<mindspore::TensorType>()) {
  276. auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
  277. outputs_type.push_back(ele_element_type);
  278. } else {
  279. MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
  280. }
  281. }
  282. } else {
  283. // in this case, the output is a single tensor
  284. if (primary_output_type->isa<mindspore::TensorType>()) {
  285. auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
  286. outputs_type.push_back(element_type);
  287. } else {
  288. MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
  289. }
  290. }
  291. return outputs_type;
  292. }
  293. bool IsElementWiseOperator(const std::string &op_name) {
  294. auto iter = std::find(elementwise_op_.begin(), elementwise_op_.end(), op_name);
  295. return (iter != elementwise_op_.end());
  296. }
  297. bool IsSplittableOperator(const std::string &op_name) {
  298. std::vector<std::string>::iterator iter;
  299. iter = std::find(splittable_op_.begin(), splittable_op_.end(), op_name);
  300. return (iter != splittable_op_.end());
  301. }
  302. bool IsAutoParallelCareNode(const CNodePtr &cnode) {
  303. MS_EXCEPTION_IF_NULL(cnode);
  304. ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  305. if (prim_node == nullptr) {
  306. return false;
  307. }
  308. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
  309. if (prim == nullptr) {
  310. return false;
  311. }
  312. bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
  313. if (bool_result) {
  314. MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
  315. } else if (prim->name() == CAST) {
  316. return true;
  317. }
  318. return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
  319. }
  320. OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
  321. MS_EXCEPTION_IF_NULL(prim);
  322. MS_EXCEPTION_IF_NULL(cnode);
  323. auto attrs = prim->attrs();
  324. std::vector<Shapes> shape_list = ExtractShape(cnode);
  325. if (shape_list.empty()) {
  326. MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
  327. }
  328. // Create an OperatorInfo instance
  329. OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
  330. MS_EXCEPTION_IF_NULL(operator_info);
  331. // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
  332. std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
  333. if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
  334. MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
  335. return nullptr;
  336. }
  337. // Set the data type for inputs and outputs of this OperatorInfo
  338. auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
  339. auto outputs_type = ExtractOutputTypeByNode(cnode);
  340. std::vector<size_t> outputs_type_length;
  341. outputs_type_length.reserve(outputs_type.size());
  342. std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
  343. GetLengthOfDataType);
  344. if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
  345. MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
  346. return nullptr;
  347. }
  348. if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
  349. MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
  350. return nullptr;
  351. }
  352. // When the 'inputs' contains numerical values for some operators, these values should be extracted from
  353. // ANF graph
  354. auto &inputs = cnode->inputs();
  355. std::vector<ValuePtr> input_value;
  356. for (size_t index = 1; index < inputs.size(); ++index) {
  357. if (inputs[index]->isa<ValueNode>()) {
  358. input_value.push_back(GetValueNode(inputs[index]));
  359. } else {
  360. input_value.emplace_back(nullptr);
  361. }
  362. }
  363. operator_info->set_input_value(input_value);
  364. operator_info->set_outputs_dtype(cnode->Type());
  365. operator_info->set_cnode(cnode);
  366. // key of strategy map
  367. std::string strategy_key_name = NodeParameterName(cnode);
  368. bool load_strategy_from_ckpt =
  369. StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
  370. // If no strategy has been configured for this operator, then candidate strategies are generated for
  371. // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
  372. // if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint .
  373. if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
  374. // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
  375. // BatchParallelInfo operator
  376. operator_info->ComputeBatchSplitFlagList();
  377. if (operator_info->GenerateStrategies(0) != SUCCESS) {
  378. MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
  379. return nullptr;
  380. }
  381. } else {
  382. // In this case, the configured strategy should be extracted to help setting cost
  383. StrategyPtr strategyPtr;
  384. if (load_strategy_from_ckpt) {
  385. strategyPtr = (*stra_map)[strategy_key_name];
  386. } else {
  387. strategyPtr = parallel::ExtractStrategy(attrs);
  388. }
  389. if (strategyPtr != nullptr) {
  390. if (prim->name() == RESHAPE) {
  391. MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
  392. }
  393. // Set cost for this configured strategy
  394. if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
  395. MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
  396. } else if (FULLY_USE_DEVICES) {
  397. // If configured to fully use devices, then checking for the user-specified strategy
  398. int32_t used_devices = operator_info->used_devices();
  399. MS_EXCEPTION_IF_NULL(g_device_manager);
  400. auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
  401. // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel
  402. if (used_devices == 1) {
  403. return operator_info;
  404. }
  405. // 'used_devices == -1' means that 'used_devices_' is not set
  406. if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) {
  407. MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, "
  408. << "but the specified strategy uses device: " << used_devices
  409. << ", total devices: " << total_device_num;
  410. }
  411. }
  412. }
  413. }
  414. return operator_info;
  415. }
  416. // Using CNode's UniqueIds to construct nodes
  417. Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  418. MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  419. entire_costgraph = std::make_shared<CostGraph>();
  420. entire_costgraph->SetDeviceMemoryAndCostParameter();
  421. // The map from CNode's UniqueId to its operatorInfo
  422. std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  423. // extract strategy from checkpoint for multi-train
  424. StrategyMap stra_map;
  425. if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
  426. if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
  427. MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  428. }
  429. }
  430. // Step 1
  431. for (auto &node : all_nodes) {
  432. // NOTE: we only care about splittable Primitive operators
  433. auto cnode = node->cast<CNodePtr>();
  434. bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
  435. if (bool_result) {
  436. continue;
  437. }
  438. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  439. if (!IsAutoParallelCareNode(cnode)) {
  440. continue;
  441. }
  442. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  443. MS_EXCEPTION_IF_NULL(prim);
  444. auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
  445. if (search_cnode == from_cnode_to_info.end()) {
  446. auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
  447. if (operator_info == nullptr) {
  448. return FAILED;
  449. }
  450. // Needed by rec_parser
  451. operator_info->set_type(prim->name());
  452. std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
  453. entire_costgraph->AddOperator(operator_info);
  454. (void)cnode->set_operator_info(operator_info);
  455. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  456. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  457. << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
  458. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
  459. // Needed by rec_parser
  460. entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
  461. } else {
  462. // Two CNODEs' UniqueIds should not be equal
  463. MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
  464. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  465. << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
  466. }
  467. }
  468. MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  469. return SUCCESS;
  470. }
  471. // Using CNode's UniqueIdThroughCopys to construct nodes
  472. Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  473. MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  474. entire_costgraph = std::make_shared<CostGraph>();
  475. entire_costgraph->SetDeviceMemoryAndCostParameter();
  476. // The map from CNode's UniqueIdThroughCopy to its operatorInfo
  477. std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  478. // extract strategy from checkpoint for multi-train
  479. StrategyMap stra_map;
  480. if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
  481. if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
  482. MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  483. }
  484. }
  485. for (auto &node : all_nodes) {
  486. // NOTE: we only care about splittable Primitive operators
  487. auto cnode = node->cast<CNodePtr>();
  488. bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
  489. if (bool_result) {
  490. continue;
  491. }
  492. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  493. if (!IsAutoParallelCareNode(cnode)) {
  494. continue;
  495. }
  496. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  497. // Find the operatorInfo if it exists
  498. auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
  499. if (search_cnode == from_cnode_to_info.end()) {
  500. // In this case, the corresponding OperatorInfo is not created, create the new one.
  501. auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
  502. if (operator_info == nullptr) {
  503. return FAILED;
  504. }
  505. // Needed by rec_parser
  506. operator_info->set_type(prim->name());
  507. std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
  508. entire_costgraph->AddOperator(operator_info);
  509. (void)cnode->set_operator_info(operator_info);
  510. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  511. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  512. << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
  513. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
  514. // Needed by rec_parser
  515. entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
  516. } else {
  517. auto current_op_ptr = search_cnode->second;
  518. if (current_op_ptr == nullptr) {
  519. MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
  520. } else {
  521. bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
  522. (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
  523. (current_op_ptr->name().find(prim->name()) == std::string::npos);
  524. if (is_find_wrong) {
  525. MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
  526. << " does not match the Prim: " << prim->name();
  527. }
  528. (void)cnode->set_operator_info(current_op_ptr);
  529. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  530. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  531. << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
  532. }
  533. }
  534. }
  535. MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  536. return SUCCESS;
  537. }
  538. void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
  539. // Step 2
  540. MS_LOG(INFO) << "Constructing edges for cost graph begins.";
  541. for (auto &node : all_nodes) {
  542. auto cnode = node->cast<CNodePtr>();
  543. bool bool_result_cnode = (cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0));
  544. if (bool_result_cnode) {
  545. continue;
  546. }
  547. auto &inputs = cnode->inputs();
  548. ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
  549. if (!IsAutoParallelCareNode(cnode)) {
  550. continue;
  551. }
  552. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  553. size_t edge_count = 0;
  554. for (size_t i = 1; i < inputs.size(); ++i) {
  555. auto prev_cnode = inputs[i]->cast<CNodePtr>();
  556. bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  557. if (bool_result_prev_cnode) {
  558. continue;
  559. }
  560. ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  561. PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  562. size_t output_index = 0;
  563. bool bool_result =
  564. (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
  565. while (bool_result) {
  566. if (IsAutoParallelCareNode(prev_cnode)) {
  567. std::string edge_name =
  568. prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name();
  569. // If the edge between these two operators already has been added, then the edge will not be added again.
  570. if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
  571. break;
  572. }
  573. EdgePtr edge_ptr;
  574. MS_LOG(INFO) << "Creating edge: " << edge_name;
  575. bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
  576. (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
  577. if (follow_strategy) {
  578. // Redistribution in not allowed on the edge.
  579. // Elementwise operators have the same strategy as their previous operators.
  580. edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
  581. output_index, i - 1, false, true);
  582. } else {
  583. edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
  584. output_index, i - 1, false);
  585. }
  586. // Init costs for this edge
  587. if (edge_ptr->InitEdgeCost() != SUCCESS) {
  588. MS_LOG(EXCEPTION) << "Edge cost initialization failed";
  589. }
  590. cnode->operator_info()->AddPrevEdge(edge_ptr);
  591. prev_cnode->operator_info()->AddSuccEdge(edge_ptr);
  592. entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr);
  593. MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and "
  594. << cnode->operator_info()->name();
  595. edge_count++;
  596. break;
  597. } else if (prev_prim->name() == TUPLE_GETITEM) {
  598. // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
  599. // this 'tuple_getitem'
  600. MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
  601. output_index = IntToSize(GetValue<int>(GetValueNode(prev_cnode->input(2))));
  602. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  603. bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  604. if (bool_result_tuple) {
  605. break;
  606. }
  607. prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  608. prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  609. if (!IsAutoParallelCareNode(prev_cnode)) {
  610. MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
  611. }
  612. MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
  613. << "and creating an edge between the Operator before "
  614. << "'tuple_getitem' and the Operator after 'tuple_getitem'.";
  615. } else if (prev_prim->name() == DEPEND) {
  616. // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before
  617. // this 'depend'
  618. MS_LOG(INFO) << "Jumping the 'depend' operator.";
  619. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  620. bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  621. if (bool_result_depend) {
  622. break;
  623. }
  624. prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  625. prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  626. MS_LOG(INFO) << "Jumped the 'depend' operator, "
  627. << "and creating an edge between the Operator before "
  628. << "'depend' and the Operator after 'depend'.";
  629. }
  630. bool_result =
  631. (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
  632. }
  633. }
  634. MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name();
  635. }
  636. MS_LOG(INFO) << "Constructing edges for cost graph ends.";
  637. }
  638. std::pair<AnfNodePtr, std::vector<AnfNodePtr>> CNodeWithRefKeys(const AnfNodePtr &cnode) {
  639. MS_EXCEPTION_IF_NULL(cnode);
  640. std::vector<AnfNodePtr> refkeys;
  641. if (cnode->isa<CNode>()) {
  642. auto cnode_ptr = cnode->cast<CNodePtr>();
  643. auto inputs = cnode_ptr->inputs();
  644. for (auto &one_input : inputs) {
  645. if (IsValueNode<RefKey>(one_input)) {
  646. refkeys.push_back(one_input);
  647. }
  648. }
  649. if (refkeys.size() >= 1) {
  650. return std::make_pair(cnode, refkeys);
  651. }
  652. }
  653. return {nullptr, refkeys};
  654. }
  655. void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  656. // Step 3
  657. for (auto &node : all_nodes) {
  658. auto cnode_with_refkeys = CNodeWithRefKeys(node);
  659. if ((!node->isa<Parameter>()) && (cnode_with_refkeys.first == nullptr)) {
  660. continue;
  661. }
  662. std::string parameter_name;
  663. AnfNodePtr target_parameter = nullptr;
  664. AnfNodeIndexSet target_set;
  665. if (cnode_with_refkeys.first != nullptr) {
  666. // Dealing with the RefKey case
  667. auto refkeys = cnode_with_refkeys.second;
  668. auto cnode = cnode_with_refkeys.first;
  669. auto cnode_ptr = cnode->cast<CNodePtr>();
  670. if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
  671. continue;
  672. }
  673. if (!IsAutoParallelCareNode(cnode_ptr)) {
  674. continue;
  675. }
  676. if (refkeys.size() > 1) {
  677. MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
  678. }
  679. MS_EXCEPTION_IF_NULL(cnode->func_graph());
  680. auto cnode_func_graph = cnode->func_graph();
  681. MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
  682. // Find the RefKey being used
  683. auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
  684. for (auto &candidate : candidate_set_by_refkey) {
  685. auto candidate_node = candidate.first;
  686. auto c = candidate_node->cast<CNodePtr>();
  687. if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
  688. continue;
  689. }
  690. if (!IsAutoParallelCareNode(c)) {
  691. continue;
  692. }
  693. target_set.add(candidate);
  694. }
  695. // Find the corresponding Parameter being used
  696. std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
  697. if (parameters.size() != 1) {
  698. MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  699. }
  700. parameter_name = parameters[0]->cast<ParameterPtr>()->name();
  701. target_parameter = parameters[0];
  702. auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
  703. for (auto &candidate : candidate_set_by_para) {
  704. auto candidate_node = candidate.first;
  705. auto c = candidate_node->cast<CNodePtr>();
  706. if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
  707. continue;
  708. }
  709. if (!IsAutoParallelCareNode(c)) {
  710. continue;
  711. }
  712. (void)target_set.insert(candidate);
  713. }
  714. } else if (node->isa<Parameter>()) {
  715. // Dealing with the Parameter case
  716. MS_EXCEPTION_IF_NULL(node->func_graph());
  717. MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
  718. auto candidate_set = node->func_graph()->manager()->node_users()[node];
  719. for (auto &candidate : candidate_set) {
  720. auto candidate_node = candidate.first;
  721. auto c = candidate_node->cast<CNodePtr>();
  722. if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
  723. continue;
  724. }
  725. if (!IsAutoParallelCareNode(c)) {
  726. continue;
  727. }
  728. (void)target_set.insert(candidate);
  729. }
  730. // In this case, node is a Parameter
  731. parameter_name = node->cast<ParameterPtr>()->name();
  732. target_parameter = node;
  733. }
  734. if (target_set.size() <= 1) {
  735. continue;
  736. }
  737. // Rule out the case when a Parameter being used by a Operator, but the Operator appears in multiple CNODEs
  738. std::set<std::string> target_without_duplicate;
  739. for (auto &target : target_set) {
  740. auto target_cnode = target.first->cast<CNodePtr>();
  741. auto input_index = target.second;
  742. (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name());
  743. }
  744. if (target_without_duplicate.size() <= 1) {
  745. continue;
  746. }
  747. // Here, it is sure that this Parameter (RefKey) is being used by multiple Operators.
  748. OperatorInfoPtr tmp_identity_ptr;
  749. bool new_identity = false;
  750. std::string tmp_identity_name;
  751. auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
  752. if (returned_identity != nullptr) {
  753. // In this case, the TmpIdentityInfo instance has already been created
  754. new_identity = false;
  755. tmp_identity_ptr = returned_identity;
  756. tmp_identity_name = tmp_identity_ptr->name();
  757. } else {
  758. // In the case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created.
  759. new_identity = true;
  760. // 1) extract input shape from this Parameter
  761. MS_EXCEPTION_IF_NULL(target_parameter);
  762. AbstractBasePtr abstract = target_parameter->abstract();
  763. if (abstract == nullptr) {
  764. MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
  765. }
  766. auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
  767. if (input_shape == nullptr) {
  768. MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
  769. }
  770. std::vector<int> shape_int = input_shape->shape();
  771. Shape shape;
  772. (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape),
  773. [](int sub_shape) { return static_cast<int32_t>(sub_shape); });
  774. Shapes inputs_shape = {shape};
  775. Shapes outputs_shape = {shape};
  776. // 2) init the attr
  777. std::unordered_map<std::string, ValuePtr> attr = {};
  778. // Create the TmpIdentity instance
  779. tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
  780. tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
  781. TOTAL_OPS++;
  782. tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
  783. // Set the parameter and type lengths for inputs and outputs
  784. std::vector<bool> is_parameter;
  785. auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
  786. MS_EXCEPTION_IF_NULL(casted_target_parameter);
  787. if (casted_target_parameter->has_default()) {
  788. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(casted_target_parameter->default_param());
  789. bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
  790. is_parameter.push_back(require_grad);
  791. } else {
  792. is_parameter.push_back(false);
  793. }
  794. if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
  795. MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
  796. }
  797. auto node_type = target_parameter->Type();
  798. if (node_type->isa<mindspore::TensorType>()) {
  799. auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
  800. std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
  801. if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
  802. MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
  803. }
  804. } else {
  805. MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
  806. }
  807. // Generate strategies for this TmpIdentityInfo instance;
  808. if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
  809. MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name();
  810. }
  811. }
  812. // A flag recording whether new edges have been created or not
  813. bool add_identity_edge = false;
  814. // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
  815. for (auto &target : target_set) {
  816. auto target_cnode = target.first->cast<CNodePtr>();
  817. auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
  818. auto input_index = target.second;
  819. std::string edge_name =
  820. std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name();
  821. // If the edge between these two operators already has been added, then the edge will not be added again.
  822. if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) {
  823. continue;
  824. }
  825. std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
  826. edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
  827. if (edge_ptr->InitEdgeCost() != SUCCESS) {
  828. MS_LOG(EXCEPTION) << "Edge cost initialization failed";
  829. }
  830. target_cnode->operator_info()->AddPrevEdge(edge_ptr);
  831. tmp_identity_ptr->AddSuccEdge(edge_ptr);
  832. entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr);
  833. MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
  834. << target_cnode->operator_info()->name();
  835. add_identity_edge = true;
  836. }
  837. if (new_identity && add_identity_edge) {
  838. // Add the TmpIdentityInfo to CostGraph if BOTH two conditions are satisfied
  839. entire_costgraph->AddOperator(tmp_identity_ptr);
  840. }
  841. }
  842. }
  843. bool FindReshape(const CNodePtr &cnode) {
  844. if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
  845. return false;
  846. }
  847. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  848. if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
  849. return false;
  850. }
  851. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  852. MS_EXCEPTION_IF_NULL(prim);
  853. OperatorInfoPtr operator_info = cnode->operator_info();
  854. if (operator_info == nullptr) {
  855. MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
  856. }
  857. if (prim->name() != RESHAPE) {
  858. return false;
  859. }
  860. return true;
  861. }
  862. // find previous node, then obtain its strategy_cost_ vector to get its layout vector.
  863. bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
  864. // if previous node is a parameter, handle it in the outsize.
  865. if (node->isa<Parameter>()) {
  866. return false;
  867. }
  868. if (!node->isa<CNode>()) {
  869. return false;
  870. }
  871. CNodePtr cnode = node->cast<CNodePtr>();
  872. if (!IsValueNode<Primitive>(cnode->input(0))) {
  873. return false;
  874. }
  875. if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
  876. *pre_operator_info = cnode->operator_info();
  877. *out_index = 0;
  878. return true;
  879. }
  880. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  881. PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  882. if (prim->name() == TUPLE_GETITEM) {
  883. *out_index = GetTupleGetItemIndex(cnode);
  884. // find tuple_get_item's previous node
  885. auto pre_node = cnode->input(1);
  886. if (!pre_node->isa<CNode>()) {
  887. MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
  888. }
  889. CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
  890. if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
  891. *pre_operator_info = pre_cnode->operator_info();
  892. return true;
  893. }
  894. return false;
  895. }
  896. for (size_t index = 0; index < cnode->inputs().size(); ++index) {
  897. if (prim->name() == DEPEND && index != 1) {
  898. continue;
  899. }
  900. if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
  901. continue;
  902. }
  903. return true;
  904. }
  905. MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
  906. return false;
  907. }
  908. // find next node, then obtain its strategy_cost_ vector to get its layout vector.
  909. // if reshape's output connect to several primitive, return the first layout found
  910. bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
  911. MS_EXCEPTION_IF_NULL(cnode);
  912. MS_EXCEPTION_IF_NULL(cnode->func_graph());
  913. FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  914. MS_EXCEPTION_IF_NULL(manager);
  915. AnfNodeIndexSet node_set = manager->node_users()[cnode];
  916. for (auto &node_pair : node_set) {
  917. CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
  918. if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
  919. continue;
  920. }
  921. ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
  922. MS_EXCEPTION_IF_NULL(prim_anf_node);
  923. PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
  924. MS_EXCEPTION_IF_NULL(node_prim);
  925. MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
  926. if (node_prim->name() == DEPEND && node_pair.second != 1) {
  927. continue;
  928. }
  929. if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
  930. MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
  931. *next_operator_info = use_apply->operator_info();
  932. *in_index = node_pair.second - 1;
  933. return true;
  934. }
  935. MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
  936. << " " << (use_apply->operator_info() != nullptr);
  937. if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
  938. return true;
  939. }
  940. }
  941. return false;
  942. }
  943. void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
  944. for (auto node : all_nodes) {
  945. auto cnode = node->cast<CNodePtr>();
  946. if (!FindReshape(cnode)) {
  947. continue;
  948. }
  949. MS_ASSERT(cnode->inputs().size() == 3);
  950. // get previous node's strategy_cost_
  951. auto pre_node = cnode->input(1);
  952. int32_t out_index = 0;
  953. OperatorInfoPtr pre_operator_info;
  954. std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
  955. if (pre_node->isa<Parameter>()) {
  956. OperatorInfoPtr operator_info = cnode->operator_info();
  957. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
  958. reshape_info->SetCostForReshapeWithParameter();
  959. pre_operator_info = reshape_info;
  960. pre_stra_costs = reshape_info->strategy_cost();
  961. } else {
  962. if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
  963. MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
  964. }
  965. pre_stra_costs = pre_operator_info->strategy_cost();
  966. }
  967. // get next node's strategy_cost_
  968. int32_t in_index = 0;
  969. OperatorInfoPtr next_operator_info;
  970. std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
  971. bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
  972. if (!find_next_node) {
  973. MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
  974. }
  975. // set input_layout and output_layout for reshape.
  976. // init reshape and set cost for each input_layout and output_layout.
  977. OperatorInfoPtr operator_info = cnode->operator_info();
  978. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
  979. reshape_info->set_pre_operator_name(pre_operator_info->name());
  980. reshape_info->set_pre_operator_index(out_index);
  981. if (find_next_node) {
  982. next_stra_costs = next_operator_info->strategy_cost();
  983. reshape_info->set_next_operator_name(next_operator_info->name());
  984. reshape_info->set_next_operator_index(in_index);
  985. }
  986. bool is_prev_param = pre_node->isa<Parameter>();
  987. if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
  988. SUCCESS) {
  989. MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!";
  990. }
  991. }
  992. }
  993. Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  994. // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
  995. // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
  996. // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
  997. // for each OperatorInfo;
  998. // Step 1.1: Deal with 'Reshape':
  999. // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
  1000. // layout as its output layout.
  1001. // Step 2: Traverse the ANF graph, and create EDGES for costgraph:
  1002. // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
  1003. // for each edge, based on the strategies of two OperatorInfos;
  1004. // Step 3: Augment the costgraph:
  1005. // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
  1006. // operator for this Parameter, and add an edge for the use of this Parameter by each
  1007. // subsequent operator;
  1008. // Step 3.1: Calculate memory usage:
  1009. // note the memory usage calculation is different in training phase and inference phase.
  1010. // Step 4: Run the Dynamic Programming algorithm:
  1011. // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
  1012. // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
  1013. // tensor layout. Note that there may be several connected components in the costgraph, and the DP algorithm
  1014. // runs on each of them.
  1015. //
  1016. // OUTPUT: the determined strategy for each operator.
  1017. // Step 1
  1018. if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
  1019. if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
  1020. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  1021. << entire_costgraph->GetOperators().size() << " operators.";
  1022. } else {
  1023. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  1024. }
  1025. } else {
  1026. if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
  1027. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  1028. << entire_costgraph->GetOperators().size() << " operators.";
  1029. } else {
  1030. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  1031. }
  1032. }
  1033. // Step 1.1
  1034. ReshapeCostCompute(all_nodes);
  1035. // Step 2
  1036. ConstructCostGraphEdges(all_nodes);
  1037. MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
  1038. << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
  1039. // Step 3: Augment the costgraph.
  1040. AugmentCostGraph(all_nodes);
  1041. MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size()
  1042. << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
  1043. // Step 3.1: Calculate the memory usage
  1044. if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
  1045. MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
  1046. }
  1047. // Step 4: run DP algorithm on the costgraph.
  1048. if (GetStrategy(entire_costgraph) != SUCCESS) {
  1049. MS_LOG(ERROR) << "Strategy search for cost-graph fails";
  1050. return FAILED;
  1051. }
  1052. MS_LOG(INFO) << "Searching strategy succeeded.";
  1053. if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
  1054. MS_LOG(INFO) << "Init selected strategy succeeded.";
  1055. } else {
  1056. MS_LOG(EXCEPTION) << "Init selected strategy failed.";
  1057. }
  1058. // print the selected strategy
  1059. for (auto &op : entire_costgraph->GetOperators()) {
  1060. StrategyPtr s_strategy = op->selected_strategy();
  1061. MS_LOG(INFO) << op->name() << " : The strategy is:";
  1062. PrintStrategy(s_strategy);
  1063. }
  1064. return SUCCESS;
  1065. }
  // Applies one tuple_getitem alias (`it->first` renamed to `it->second`) to a copy of `input_tensor_names`
  // and returns that copy. In each operator's name list only the FIRST matching entry is renamed.
  std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
                                                            std::vector<std::vector<std::string>> input_tensor_names) {
    const std::string &alias = it->first;
    for (auto &op_inputs : input_tensor_names) {
      auto hit = std::find(op_inputs.begin(), op_inputs.end(), alias);
      if (hit != op_inputs.end()) {
        *hit = it->second;
      }
    }
    return input_tensor_names;
  }
  1078. Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  1079. if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
  1080. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
  1081. << " operators.";
  1082. } else {
  1083. MS_LOG(ERROR) << "Constructing nodes for cost graph failed.";
  1084. return FAILED;
  1085. }
  1086. auto ops = entire_costgraph->GetOperators();
  1087. std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
  1088. auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
  1089. for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
  1090. input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
  1091. }
  1092. std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
  1093. std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
  1094. std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
  1095. graph = EliminateGraph(graph, eli_list, index_list);
  1096. size_t num_device = g_device_manager->DeviceNum();
  1097. double device_memory = entire_costgraph->GetDeviceMemory();
  1098. if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
  1099. MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
  1100. } else {
  1101. MS_LOG(ERROR) << "PartitionForAllDevices failed.";
  1102. return FAILED;
  1103. }
  1104. GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list);
  1105. if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
  1106. MS_LOG(INFO) << "Init selected strategy succeeded.";
  1107. } else {
  1108. MS_LOG(ERROR) << "Init selected strategy failed.";
  1109. return FAILED;
  1110. }
  1111. // print the selected strategy
  1112. for (auto &op : entire_costgraph->GetOperators()) {
  1113. StrategyPtr s_strategy = op->selected_strategy();
  1114. MS_LOG(INFO) << op->name() << " : The strategy is:";
  1115. PrintStrategy(s_strategy);
  1116. }
  1117. return SUCCESS;
  1118. }
  1119. } // namespace parallel
  1120. } // namespace mindspore