You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

step_auto_parallel.cc 45 kB

6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "parallel/step_auto_parallel.h"
  #include <inttypes.h>
  #include <sys/time.h>
  #include <algorithm>
  #include <iterator>
  #include <map>
  #include <memory>
  #include <set>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  #include "ir/anf.h"
  #include "ir/meta_tensor.h"
  #include "optimizer/opt.h"
  #include "optimizer/optimizer.h"
  #include "parallel/auto_parallel/dp_algo_costmodel.h"
  #include "parallel/auto_parallel/edge_costmodel.h"
  #include "parallel/auto_parallel/graph_costmodel.h"
  #include "parallel/auto_parallel/rec_core/rec_generate_strategy.h"
  #include "parallel/auto_parallel/rec_core/rec_parse_graph.h"
  #include "parallel/auto_parallel/rec_core/rec_partition.h"
  #include "parallel/context.h"
  #include "parallel/ops_info/tmp_identity_info.h"
  #include "parallel/step_parallel.h"
  #include "pipeline/parse/python_adapter.h"
  #include "pipeline/pipeline.h"
  42. namespace mindspore {
  43. namespace parallel {
  44. // splittable_op_ will continuously be updated
  45. std::vector<std::string> splittable_op_ = {MATMUL,
  46. GELU,
  47. TANH,
  48. SOFTMAX,
  49. LOG_SOFTMAX,
  50. ACTIVATION,
  51. PRELU,
  52. FLOORDIV,
  53. L2_NORMALIZE,
  54. TRANSPOSE,
  55. RESHAPE,
  56. TENSOR_ADD,
  57. SUB,
  58. MUL,
  59. DIV,
  60. GREATER,
  61. MAXPOOL,
  62. MAXPOOLV2,
  63. VIRTUAL_DATA_SET,
  64. SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
  65. RELU,
  66. ONEHOT,
  67. DROPOUT_DO_MASK,
  68. REDUCE_MAX,
  69. REDUCE_MIN,
  70. ARGMAXWITHVALUE,
  71. ARGMINWITHVALUE,
  72. REDUCE_SUM,
  73. CONV2D,
  74. FUSE_BATCH_NORM,
  75. POOLING,
  76. SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
  77. SIGMOID_CROSS_ENTROPY_WITH_LOGITS,
  78. MAX_POOL_WITH_ARGMAX,
  79. SIMPLE_MEAN,
  80. FLATTEN,
  81. BATCH_NORM,
  82. LAYER_NORM,
  83. BIAS_ADD,
  84. ASSIGN_SUB,
  85. COS,
  86. ACOS,
  87. EXP,
  88. LOG,
  89. REDUCE_MEAN,
  90. REAL_DIV,
  91. SIGMOID,
  92. POW,
  93. MAXIMUM,
  94. MINIMUM,
  95. EQUAL,
  96. NOT_EQUAL,
  97. LOGICALNOT,
  98. GATHERV2,
  99. STRIDEDSLICE,
  100. SQRT,
  101. GET_NEXT,
  102. CAST,
  103. NEG,
  104. SQUARE,
  105. BATCH_MATMUL,
  106. EXPAND_DIMS,
  107. SQUEEZE};
  108. std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT, CAST,
  109. POW, EXP, LOG, COS, ACOS, LOGICALNOT, NEG, SQUARE};
  110. bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
  111. MS_EXCEPTION_IF_NULL(root);
  112. MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  113. std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  114. // assume no change to graph
  115. bool changes = false;
  116. // control whether use model_parallel mode
  117. if ((parallel_mode != AUTO_PARALLEL) || root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY]) {
  118. return changes;
  119. }
  120. // check whether strategy_search_mode is valid
  121. std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
  122. if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
  123. // Setting searching mode: dynanic programming as default.
  124. strategy_search_mode = DYNAMIC_PROGRAMMING;
  125. MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default";
  126. }
  127. struct timeval start_time, end_time;
  128. (void)gettimeofday(&start_time, nullptr);
  129. if (MsContext::GetInstance()->save_graphs_flag()) {
  130. draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
  131. }
  132. MS_LOG(INFO) << "Now entering step auto parallel";
  133. TOTAL_OPS = 0;
  134. AnfNodePtr ret = root->get_return();
  135. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  136. if (ParallelInit() != SUCCESS) {
  137. MS_LOG(EXCEPTION) << "Parallel init failed";
  138. }
  139. // mark the forward cnodes, parallel only care these nodes
  140. MarkForwardCNode(root);
  141. if (FindCommunicationOp(all_nodes)) {
  142. MS_LOG(EXCEPTION) << "The graph contain communication op";
  143. }
  144. // search parallelization strategy
  145. if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
  146. if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
  147. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
  148. }
  149. } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
  150. if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
  151. MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
  152. }
  153. } else {
  154. MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
  155. }
  156. (void)gettimeofday(&end_time, nullptr);
  157. uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  158. time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  159. MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
  160. root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY] = true;
  161. return changes;
  162. }
  163. // Given the node, return whether each input is a parameter or a output of a operator.
  164. // The returned boolean vector should be the same order of the inputs, thus its implementation
  165. // is closely consistent with ExtractShape() in step_parallel.cc
  166. std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
  167. std::vector<bool> is_parameter;
  168. std::vector<AnfNodePtr> node_inputs{node->inputs()};
  169. for (size_t i = 1; i < node_inputs.size(); ++i) {
  170. auto input = node_inputs[i];
  171. if (input->isa<Parameter>()) {
  172. auto input_parameter = input->cast<ParameterPtr>();
  173. if (input_parameter->has_default()) {
  174. bool require_grad =
  175. py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
  176. is_parameter.push_back(require_grad);
  177. } else {
  178. is_parameter.push_back(false);
  179. }
  180. } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
  181. is_parameter.push_back(false);
  182. }
  183. }
  184. return is_parameter;
  185. }
  186. // Given the type, return the number of bytes to represent this type
  187. size_t GetLengthOfDataType(const TypePtr &type) {
  188. switch (type->type_id()) {
  189. case kNumberTypeBool:
  190. return sizeof(bool);
  191. case kNumberTypeInt8:
  192. return sizeof(int8_t);
  193. case kNumberTypeInt16:
  194. return sizeof(int16_t);
  195. case kNumberTypeInt32:
  196. return sizeof(int32_t);
  197. case kNumberTypeInt64:
  198. return sizeof(int64_t);
  199. case kNumberTypeUInt8:
  200. return sizeof(uint8_t);
  201. case kNumberTypeUInt16:
  202. return sizeof(uint16_t);
  203. case kNumberTypeUInt32:
  204. return sizeof(uint32_t);
  205. case kNumberTypeUInt64:
  206. return sizeof(uint64_t);
  207. case kNumberTypeFloat16:
  208. return sizeof(float) / 2;
  209. case kNumberTypeFloat32:
  210. return sizeof(float);
  211. case kNumberTypeFloat64:
  212. return sizeof(double);
  213. case kNumberTypeInt:
  214. return sizeof(int);
  215. case kNumberTypeUInt:
  216. return sizeof(unsigned int);
  217. case kNumberTypeFloat:
  218. return sizeof(float);
  219. default:
  220. MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  221. }
  222. }
  223. size_t GetInputsTypeLen(const AnfNodePtr &input) {
  224. MS_EXCEPTION_IF_NULL(input);
  225. if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
  226. MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
  227. }
  228. size_t input_type_len = 0;
  229. auto type = input->Type();
  230. MS_EXCEPTION_IF_NULL(type);
  231. if (type->isa<mindspore::TensorType>()) {
  232. auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
  233. input_type_len = GetLengthOfDataType(input_element_type);
  234. } else {
  235. MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  236. }
  237. return input_type_len;
  238. }
  239. std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
  240. MS_EXCEPTION_IF_NULL(node);
  241. std::vector<size_t> inputs_type_len;
  242. std::vector<AnfNodePtr> node_inputs{node->inputs()};
  243. // extract input element length
  244. for (auto &input : node_inputs) {
  245. if (IsValueNode<RefKey>(input)) {
  246. auto func_graph = node->func_graph();
  247. MS_EXCEPTION_IF_NULL(func_graph);
  248. std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
  249. if (parameters.size() != 1) {
  250. MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  251. }
  252. inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
  253. } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
  254. // extract input shape from parameter and apply node
  255. inputs_type_len.push_back(GetInputsTypeLen(input));
  256. }
  257. }
  258. return inputs_type_len;
  259. }
  260. std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
  261. MS_EXCEPTION_IF_NULL(node);
  262. std::vector<TypePtr> outputs_type;
  263. // extract output element type
  264. auto primary_output_type = node->Type();
  265. MS_EXCEPTION_IF_NULL(primary_output_type);
  266. if (primary_output_type->isa<mindspore::Tuple>()) {
  267. // in this case, the output is a tuple
  268. auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
  269. auto elements = tuple_output_type->elements();
  270. for (auto &ele : elements) {
  271. if (ele->isa<mindspore::TensorType>()) {
  272. auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
  273. outputs_type.push_back(ele_element_type);
  274. } else {
  275. MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
  276. }
  277. }
  278. } else {
  279. // in this case, the output is a single tensor
  280. if (primary_output_type->isa<mindspore::TensorType>()) {
  281. auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
  282. outputs_type.push_back(element_type);
  283. } else {
  284. MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
  285. }
  286. }
  287. return outputs_type;
  288. }
  289. bool IsElementWiseOperator(const std::string &op_name) {
  290. auto iter = std::find(elementwise_op_.begin(), elementwise_op_.end(), op_name);
  291. return (iter != elementwise_op_.end());
  292. }
  293. bool IsSplittableOperator(const std::string &op_name) {
  294. std::vector<std::string>::iterator iter;
  295. iter = std::find(splittable_op_.begin(), splittable_op_.end(), op_name);
  296. return (iter != splittable_op_.end());
  297. }
  298. bool IsAutoParallelCareNode(const CNodePtr &cnode) {
  299. MS_EXCEPTION_IF_NULL(cnode);
  300. ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  301. if (prim_node == nullptr) {
  302. return false;
  303. }
  304. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
  305. if (prim == nullptr) {
  306. return false;
  307. }
  308. bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
  309. if (bool_result) {
  310. MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
  311. } else if (prim->name() == CAST) {
  312. return true;
  313. }
  314. return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
  315. }
  316. OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
  317. MS_EXCEPTION_IF_NULL(prim);
  318. MS_EXCEPTION_IF_NULL(cnode);
  319. auto attrs = prim->attrs();
  320. std::vector<Shapes> shape_list = ExtractShape(cnode);
  321. if (shape_list.empty()) {
  322. MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
  323. }
  324. // Create an OperatorInfo instance
  325. OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
  326. MS_EXCEPTION_IF_NULL(operator_info);
  327. // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
  328. std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
  329. if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
  330. MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
  331. return nullptr;
  332. }
  333. // Set the data type for inputs and outputs of this OperatorInfo
  334. auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
  335. auto outputs_type = ExtractOutputTypeByNode(cnode);
  336. std::vector<size_t> outputs_type_length;
  337. outputs_type_length.reserve(outputs_type.size());
  338. std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
  339. GetLengthOfDataType);
  340. if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
  341. MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
  342. return nullptr;
  343. }
  344. if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
  345. MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
  346. return nullptr;
  347. }
  348. // When the 'inputs' contains numerical values for some operators, these values should be extracted from
  349. // ANF graph
  350. auto &inputs = cnode->inputs();
  351. std::vector<ValuePtr> input_value;
  352. for (size_t index = 1; index < inputs.size(); ++index) {
  353. if (inputs[index]->isa<ValueNode>()) {
  354. input_value.push_back(GetValueNode(inputs[index]));
  355. } else {
  356. input_value.emplace_back(nullptr);
  357. }
  358. }
  359. operator_info->set_input_value(input_value);
  360. operator_info->set_outputs_dtype(cnode->Type());
  361. operator_info->set_cnode(cnode);
  362. // If no strategy has been configured for this operator, then candidate strategies are generated for
  363. // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy
  364. if (!StrategyFound(attrs) || prim->name() == CAST) {
  365. // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
  366. // BatchParallelInfo operator
  367. operator_info->ComputeBatchSplitFlagList();
  368. if (operator_info->GenerateStrategies(0) != SUCCESS) {
  369. MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
  370. return nullptr;
  371. }
  372. } else {
  373. // In this case, the configured strategy should be extracted to help setting cost
  374. StrategyPtr strategyPtr = parallel::ExtractStrategy(attrs);
  375. if (strategyPtr != nullptr) {
  376. if (prim->name() == RESHAPE) {
  377. MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
  378. }
  379. // Set cost for this configured strategy
  380. if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
  381. MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
  382. } else if (FULLY_USE_DEVICES) {
  383. // If configured to fully use devices, then checking for the user-specified strategy
  384. int32_t used_devices = operator_info->used_devices();
  385. MS_EXCEPTION_IF_NULL(g_device_manager);
  386. auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
  387. // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel
  388. if (used_devices == 1) {
  389. return operator_info;
  390. }
  391. // 'used_devices == -1' means that 'used_devices_' is not set
  392. if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) {
  393. MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, "
  394. << "but the specified strategy uses device: " << used_devices
  395. << ", total devices: " << total_device_num;
  396. }
  397. }
  398. }
  399. }
  400. return operator_info;
  401. }
  402. // Using CNode's UniqueIds to construct nodes
  403. Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  404. MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  405. entire_costgraph = std::make_shared<CostGraph>();
  406. entire_costgraph->SetDeviceMemoryAndCostParameter();
  407. // The map from CNode's UniqueId to its operatorInfo
  408. std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  409. // Step 1
  410. for (auto &node : all_nodes) {
  411. // NOTE: we only care about splittable Primitive operators
  412. auto cnode = node->cast<CNodePtr>();
  413. bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
  414. if (bool_result) {
  415. continue;
  416. }
  417. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  418. if (!IsAutoParallelCareNode(cnode)) {
  419. continue;
  420. }
  421. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  422. MS_EXCEPTION_IF_NULL(prim);
  423. auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
  424. if (search_cnode == from_cnode_to_info.end()) {
  425. auto operator_info = CreateTheOperatorInfo(prim, cnode);
  426. if (operator_info == nullptr) {
  427. return FAILED;
  428. }
  429. // Needed by rec_parser
  430. operator_info->set_type(prim->name());
  431. std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
  432. entire_costgraph->AddOperator(operator_info);
  433. (void)cnode->set_operator_info(operator_info);
  434. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  435. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  436. << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
  437. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
  438. // Needed by rec_parser
  439. entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
  440. } else {
  441. // Two CNODEs' UniqueIds should not be equal
  442. MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
  443. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  444. << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
  445. }
  446. }
  447. MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  448. return SUCCESS;
  449. }
  450. // Using CNode's UniqueIdThroughCopys to construct nodes
  451. Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  452. MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  453. entire_costgraph = std::make_shared<CostGraph>();
  454. entire_costgraph->SetDeviceMemoryAndCostParameter();
  455. // The map from CNode's UniqueIdThroughCopy to its operatorInfo
  456. std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  457. for (auto &node : all_nodes) {
  458. // NOTE: we only care about splittable Primitive operators
  459. auto cnode = node->cast<CNodePtr>();
  460. bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
  461. if (bool_result) {
  462. continue;
  463. }
  464. ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  465. if (!IsAutoParallelCareNode(cnode)) {
  466. continue;
  467. }
  468. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  469. // Find the operatorInfo if it exists
  470. auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
  471. if (search_cnode == from_cnode_to_info.end()) {
  472. // In this case, the corresponding OperatorInfo is not created, create the new one.
  473. auto operator_info = CreateTheOperatorInfo(prim, cnode);
  474. if (operator_info == nullptr) {
  475. return FAILED;
  476. }
  477. // Needed by rec_parser
  478. operator_info->set_type(prim->name());
  479. std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
  480. entire_costgraph->AddOperator(operator_info);
  481. (void)cnode->set_operator_info(operator_info);
  482. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  483. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  484. << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
  485. (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
  486. // Needed by rec_parser
  487. entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
  488. } else {
  489. auto current_op_ptr = search_cnode->second;
  490. if (current_op_ptr == nullptr) {
  491. MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
  492. } else {
  493. bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
  494. (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
  495. (current_op_ptr->name().find(prim->name()) == std::string::npos);
  496. if (is_find_wrong) {
  497. MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
  498. << " does not match the Prim: " << prim->name();
  499. }
  500. (void)cnode->set_operator_info(current_op_ptr);
  501. MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
  502. << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
  503. << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
  504. }
  505. }
  506. }
  507. MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
  508. return SUCCESS;
  509. }
  510. void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
  511. // Step 2
  512. MS_LOG(INFO) << "Constructing edges for cost graph begins.";
  513. for (auto &node : all_nodes) {
  514. auto cnode = node->cast<CNodePtr>();
  515. bool bool_result_cnode = (cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0));
  516. if (bool_result_cnode) {
  517. continue;
  518. }
  519. auto &inputs = cnode->inputs();
  520. ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
  521. if (!IsAutoParallelCareNode(cnode)) {
  522. continue;
  523. }
  524. PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  525. size_t edge_count = 0;
  526. for (size_t i = 1; i < inputs.size(); ++i) {
  527. auto prev_cnode = inputs[i]->cast<CNodePtr>();
  528. bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  529. if (bool_result_prev_cnode) {
  530. continue;
  531. }
  532. ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  533. PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  534. size_t output_index = 0;
  535. bool bool_result =
  536. (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
  537. while (bool_result) {
  538. if (IsAutoParallelCareNode(prev_cnode)) {
  539. std::string edge_name =
  540. prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name();
  541. // If the edge between these two operators already has been added, then the edge will not be added again.
  542. if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
  543. break;
  544. }
  545. EdgePtr edge_ptr;
  546. MS_LOG(INFO) << "Creating edge: " << edge_name;
  547. bool follow_strategy = ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name());
  548. if (follow_strategy) {
  549. // Redistribution in not allowed on the edge.
  550. // Elementwise operators have the same strategy as their previous operators.
  551. edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
  552. output_index, i - 1, false, true);
  553. } else {
  554. edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
  555. output_index, i - 1, false);
  556. }
  557. // Init costs for this edge
  558. if (edge_ptr->InitEdgeCost() != SUCCESS) {
  559. MS_LOG(EXCEPTION) << "Edge cost initialization failed";
  560. }
  561. cnode->operator_info()->AddPrevEdge(edge_ptr);
  562. prev_cnode->operator_info()->AddSuccEdge(edge_ptr);
  563. entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr);
  564. MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and "
  565. << cnode->operator_info()->name();
  566. edge_count++;
  567. break;
  568. } else if (prev_prim->name() == TUPLE_GETITEM) {
  569. // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
  570. // this 'tuple_getitem'
  571. MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
  572. output_index = IntToSize(GetValue<int>(GetValueNode(prev_cnode->input(2))));
  573. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  574. bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  575. if (bool_result_tuple) {
  576. break;
  577. }
  578. prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  579. prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  580. if (!IsAutoParallelCareNode(prev_cnode)) {
  581. MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
  582. }
  583. MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
  584. << "and creating an edge between the Operator before "
  585. << "'tuple_getitem' and the Operator after 'tuple_getitem'.";
  586. } else if (prev_prim->name() == DEPEND) {
  587. // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before
  588. // this 'depend'
  589. MS_LOG(INFO) << "Jumping the 'depend' operator.";
  590. prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
  591. bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
  592. if (bool_result_depend) {
  593. break;
  594. }
  595. prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
  596. prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
  597. MS_LOG(INFO) << "Jumped the 'depend' operator, "
  598. << "and creating an edge between the Operator before "
  599. << "'depend' and the Operator after 'depend'.";
  600. }
  601. bool_result =
  602. (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
  603. }
  604. }
  605. MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name();
  606. }
  607. MS_LOG(INFO) << "Constructing edges for cost graph ends.";
  608. }
  609. std::pair<AnfNodePtr, std::vector<AnfNodePtr>> CNodeWithRefKeys(const AnfNodePtr &cnode) {
  610. MS_EXCEPTION_IF_NULL(cnode);
  611. std::vector<AnfNodePtr> refkeys;
  612. if (cnode->isa<CNode>()) {
  613. auto cnode_ptr = cnode->cast<CNodePtr>();
  614. auto inputs = cnode_ptr->inputs();
  615. for (auto &one_input : inputs) {
  616. if (IsValueNode<RefKey>(one_input)) {
  617. refkeys.push_back(one_input);
  618. }
  619. }
  620. if (refkeys.size() >= 1) {
  621. return std::make_pair(cnode, refkeys);
  622. }
  623. }
  624. return {nullptr, refkeys};
  625. }
  626. void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  627. // Step 3
  628. for (auto &node : all_nodes) {
  629. auto cnode_with_refkeys = CNodeWithRefKeys(node);
  630. if ((!node->isa<Parameter>()) && (cnode_with_refkeys.first == nullptr)) {
  631. continue;
  632. }
  633. std::string parameter_name;
  634. AnfNodePtr target_parameter = nullptr;
  635. AnfNodeIndexSet target_set;
  636. if (cnode_with_refkeys.first != nullptr) {
  637. // Dealing with the RefKey case
  638. auto refkeys = cnode_with_refkeys.second;
  639. auto cnode = cnode_with_refkeys.first;
  640. auto cnode_ptr = cnode->cast<CNodePtr>();
  641. if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
  642. continue;
  643. }
  644. if (!IsAutoParallelCareNode(cnode_ptr)) {
  645. continue;
  646. }
  647. if (refkeys.size() > 1) {
  648. MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
  649. }
  650. MS_EXCEPTION_IF_NULL(cnode->func_graph());
  651. auto cnode_func_graph = cnode->func_graph();
  652. MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
  653. // Find the RefKey being used
  654. auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
  655. for (auto &candidate : candidate_set_by_refkey) {
  656. auto candidate_node = candidate.first;
  657. auto c = candidate_node->cast<CNodePtr>();
  658. if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
  659. continue;
  660. }
  661. if (!IsAutoParallelCareNode(c)) {
  662. continue;
  663. }
  664. target_set.add(candidate);
  665. }
  666. // Find the corresponding Parameter being used
  667. std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
  668. if (parameters.size() != 1) {
  669. MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  670. }
  671. parameter_name = parameters[0]->cast<ParameterPtr>()->name();
  672. target_parameter = parameters[0];
  673. auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
  674. for (auto &candidate : candidate_set_by_para) {
  675. auto candidate_node = candidate.first;
  676. auto c = candidate_node->cast<CNodePtr>();
  677. if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
  678. continue;
  679. }
  680. if (!IsAutoParallelCareNode(c)) {
  681. continue;
  682. }
  683. (void)target_set.insert(candidate);
  684. }
  685. } else if (node->isa<Parameter>()) {
  686. // Dealing with the Parameter case
  687. MS_EXCEPTION_IF_NULL(node->func_graph());
  688. MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
  689. auto candidate_set = node->func_graph()->manager()->node_users()[node];
  690. for (auto &candidate : candidate_set) {
  691. auto candidate_node = candidate.first;
  692. auto c = candidate_node->cast<CNodePtr>();
  693. if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
  694. continue;
  695. }
  696. if (!IsAutoParallelCareNode(c)) {
  697. continue;
  698. }
  699. (void)target_set.insert(candidate);
  700. }
  701. // In this case, node is a Parameter
  702. parameter_name = node->cast<ParameterPtr>()->name();
  703. target_parameter = node;
  704. }
  705. if (target_set.size() <= 1) {
  706. continue;
  707. }
  708. // Rule out the case when a Parameter being used by a Operator, but the Operator appears in multiple CNODEs
  709. std::set<std::string> target_without_duplicate;
  710. for (auto &target : target_set) {
  711. auto target_cnode = target.first->cast<CNodePtr>();
  712. auto input_index = target.second;
  713. (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name());
  714. }
  715. if (target_without_duplicate.size() <= 1) {
  716. continue;
  717. }
  718. // Here, it is sure that this Parameter (RefKey) is being used by multiple Operators.
  719. OperatorInfoPtr tmp_identity_ptr;
  720. bool new_identity = false;
  721. std::string tmp_identity_name;
  722. auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
  723. if (returned_identity != nullptr) {
  724. // In this case, the TmpIdentityInfo instance has already been created
  725. new_identity = false;
  726. tmp_identity_ptr = returned_identity;
  727. tmp_identity_name = tmp_identity_ptr->name();
  728. } else {
  729. // In the case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created.
  730. new_identity = true;
  731. // 1) extract input shape from this Parameter
  732. MS_EXCEPTION_IF_NULL(target_parameter);
  733. AbstractBasePtr abstract = target_parameter->abstract();
  734. if (abstract == nullptr) {
  735. MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
  736. }
  737. auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
  738. if (input_shape == nullptr) {
  739. MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
  740. }
  741. std::vector<int> shape_int = input_shape->shape();
  742. Shape shape;
  743. (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape),
  744. [](int sub_shape) { return static_cast<int32_t>(sub_shape); });
  745. Shapes inputs_shape = {shape};
  746. Shapes outputs_shape = {shape};
  747. // 2) init the attr
  748. std::unordered_map<std::string, ValuePtr> attr = {};
  749. // Create the TmpIdentity instance
  750. tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
  751. tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
  752. TOTAL_OPS++;
  753. tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
  754. // Set the parameter and type lengths for inputs and outputs
  755. std::vector<bool> is_parameter;
  756. auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
  757. MS_EXCEPTION_IF_NULL(casted_target_parameter);
  758. if (casted_target_parameter->has_default()) {
  759. bool require_grad = py::cast<bool>(
  760. parse::python_adapter::GetPyObjAttr(casted_target_parameter->default_param(), "requires_grad"));
  761. is_parameter.push_back(require_grad);
  762. } else {
  763. is_parameter.push_back(false);
  764. }
  765. if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
  766. MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
  767. }
  768. auto node_type = target_parameter->Type();
  769. if (node_type->isa<mindspore::TensorType>()) {
  770. auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
  771. std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
  772. if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
  773. MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
  774. }
  775. } else {
  776. MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
  777. }
  778. // Generate strategies for this TmpIdentityInfo instance;
  779. if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
  780. MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name();
  781. }
  782. }
  783. // A flag recording whether new edges have been created or not
  784. bool add_identity_edge = false;
  785. // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
  786. for (auto &target : target_set) {
  787. auto target_cnode = target.first->cast<CNodePtr>();
  788. auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
  789. auto input_index = target.second;
  790. std::string edge_name =
  791. std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name();
  792. // If the edge between these two operators already has been added, then the edge will not be added again.
  793. if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) {
  794. continue;
  795. }
  796. std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
  797. edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
  798. if (edge_ptr->InitEdgeCost() != SUCCESS) {
  799. MS_LOG(EXCEPTION) << "Edge cost initialization failed";
  800. }
  801. target_cnode->operator_info()->AddPrevEdge(edge_ptr);
  802. tmp_identity_ptr->AddSuccEdge(edge_ptr);
  803. entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr);
  804. MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
  805. << target_cnode->operator_info()->name();
  806. add_identity_edge = true;
  807. }
  808. if (new_identity && add_identity_edge) {
  809. // Add the TmpIdentityInfo to CostGraph if BOTH two conditions are satisfied
  810. entire_costgraph->AddOperator(tmp_identity_ptr);
  811. }
  812. }
  813. }
  814. Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  815. // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
  816. // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
  817. // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
  818. // for each OperatorInfo;
  819. // Step 2: Traverse the ANF graph, and create EDGES for costgraph:
  820. // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
  821. // for each edge, based on the strategies of two OperatorInfos;
  822. // Step 3: Augment the costgraph:
  823. // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
  824. // operator for this Parameter, and add an edge for the use of this Parameter by each
  825. // subsequent operator;
  826. // Step 3.1: Calculate memory usage
  827. // Step 4: Run the Dynamic Programming algorithm:
  828. // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
  829. // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
  830. // tensor layout. Note that there may be several connected components in the costgraph, and the DP algorithm
  831. // runs on each of them.
  832. //
  833. // OUTPUT: the determined strategy for each operator.
  834. // Step 1
  835. if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
  836. if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
  837. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  838. << entire_costgraph->GetOperators().size() << " operators.";
  839. } else {
  840. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  841. }
  842. } else {
  843. if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
  844. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
  845. << entire_costgraph->GetOperators().size() << " operators.";
  846. } else {
  847. MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
  848. }
  849. }
  850. // Step 2
  851. ConstructCostGraphEdges(all_nodes);
  852. MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
  853. << " operators, and " << entire_costgraph->GetNumPairs() << " edges.",
  854. // Step 3: Augment the costgraph.
  855. AugmentCostGraph(all_nodes);
  856. MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size()
  857. << " operators, and " << entire_costgraph->GetNumPairs() << " edges.";
  858. // Step 3.1: Calculate the memory usage
  859. if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
  860. // Calculate operators' memory usage
  861. if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {
  862. MS_LOG(EXCEPTION) << "Calculating operators' cost for memory cost failed.";
  863. }
  864. // Calculate edges' memory usage
  865. if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) {
  866. MS_LOG(EXCEPTION) << "Calculating edges' cost for memory cost failed.";
  867. }
  868. // Correct memory usage caused by TmpIdentity
  869. if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) {
  870. MS_LOG(EXCEPTION) << "Correcting operators' cost for memory cost failed.";
  871. }
  872. } else {
  873. MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed.";
  874. }
  875. // Step 4: run DP algorithm on the costgraph.
  876. if (GetStrategy(entire_costgraph) != SUCCESS) {
  877. MS_LOG(ERROR) << "Strategy search for cost-graph fails";
  878. return FAILED;
  879. }
  880. MS_LOG(INFO) << "Searching strategy succeeded.";
  881. if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
  882. MS_LOG(INFO) << "Init selected strategy succeeded.";
  883. } else {
  884. MS_LOG(EXCEPTION) << "Init selected strategy failed.";
  885. }
  886. // print the selected strategy
  887. for (auto &op : entire_costgraph->GetOperators()) {
  888. StrategyPtr s_strategy = op->selected_strategy();
  889. MS_LOG(INFO) << op->name() << " : The strategy is:";
  890. PrintStrategy(s_strategy);
  891. }
  892. return SUCCESS;
  893. }
  894. std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
  895. std::vector<std::vector<std::string>> input_tensor_names) {
  896. for (size_t j = 0; j < input_tensor_names.size(); j++) {
  897. for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
  898. if (it->first == input_tensor_names[j][k]) {
  899. input_tensor_names[j][k] = it->second;
  900. break;
  901. }
  902. }
  903. }
  904. return input_tensor_names;
  905. }
  906. Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  907. if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
  908. MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
  909. << " operators.";
  910. } else {
  911. MS_LOG(ERROR) << "Constructing nodes for cost graph failed.";
  912. return FAILED;
  913. }
  914. auto ops = entire_costgraph->GetOperators();
  915. std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
  916. auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
  917. for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
  918. input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
  919. }
  920. std::shared_ptr<std::vector<size_t>> ops_nodes_list(new std::vector<size_t>);
  921. std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
  922. size_t num_device = g_device_manager->DeviceNum();
  923. double device_memory = entire_costgraph->GetDeviceMemory();
  924. if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
  925. MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
  926. } else {
  927. MS_LOG(ERROR) << "PartitionForAllDevices failed.";
  928. return FAILED;
  929. }
  930. bool mask_special_ops = true;
  931. GenerateStrategy(graph, mask_special_ops, ops);
  932. if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
  933. MS_LOG(INFO) << "Init selected strategy succeeded.";
  934. } else {
  935. MS_LOG(ERROR) << "Init selected strategy failed.";
  936. return FAILED;
  937. }
  938. // print the selected strategy
  939. for (auto &op : entire_costgraph->GetOperators()) {
  940. StrategyPtr s_strategy = op->selected_strategy();
  941. MS_LOG(INFO) << op->name() << " : The strategy is:";
  942. PrintStrategy(s_strategy);
  943. }
  944. return SUCCESS;
  945. }
  946. } // namespace parallel
  947. } // namespace mindspore