You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph_costmodel.cc 60 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/auto_parallel/graph_costmodel.h"
  17. #include <algorithm>
  18. #include <cstdlib>
  19. #include <iterator>
  20. #include <numeric>
  21. #include <string>
  22. #include <utility>
  23. #include <vector>
  24. namespace mindspore {
  25. namespace parallel {
  // File-level state for the auto-parallel cost model.
  // NOTE(review): entire_costgraph and TOTAL_OPS are not written anywhere in the visible part of this file;
  // presumably they are populated by the code that builds the cost graph -- confirm against callers.
  CostGraphPtr entire_costgraph = nullptr;
  size_t TOTAL_OPS = 0;
  // The settings below start from their DEFAULT_* values and are overwritten by
  // CostGraph::SetDeviceMemoryAndCostParameter() with the values held by CostModelContext.
  // Weight applied to (communication_cost_ - communication_without_parameter_) when the final cost lists compute
  // communication_with_partial_para_; validated to lie in [0, 1].
  double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA;
  bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
  // Validated to be strictly positive when loaded from the context.
  double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY;
  // The three communication parameters are validated to be >= 0 when loaded from the context.
  double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
  double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST;
  double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS;
  bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
  // Validated to be non-zero when loaded from the context.
  size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
  bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
  bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
  38. void CostGraph::SetDeviceMemoryAndCostParameter() {
  39. MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
  40. // DEVICE_MEMORY_CAPACITY
  41. auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
  42. if (device_memory <= 0) {
  43. MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
  44. }
  45. dev_memory_ = device_memory;
  46. DEVICE_MEMORY_CAPACITY = device_memory;
  47. MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << ".";
  48. // COST_MODEL_ALPHA
  49. auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
  50. if (alpha <= 0) {
  51. MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
  52. }
  53. costmodel_alpha_ = alpha;
  54. MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
  55. // COST_MODEL_BETA
  56. auto beta = CostModelContext::GetInstance()->costmodel_beta();
  57. if (beta <= 0) {
  58. MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
  59. }
  60. costmodel_beta_ = beta;
  61. MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
  62. // COST_MODEL_GAMMA
  63. auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  64. if ((gamma < 0) || (gamma > 1)) {
  65. MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1].";
  66. }
  67. COST_MODEL_GAMMA = gamma;
  68. MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << ".";
  69. // COST_MODEL_SIMPLIFY_CALCULATION
  70. auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal();
  71. COST_MODEL_SIMPLIFY_CALCULATION = simplify;
  72. if (COST_MODEL_SIMPLIFY_CALCULATION) {
  73. MS_LOG(INFO) << "costmodel_simplify_cal: true.";
  74. } else {
  75. MS_LOG(INFO) << "costmodel_simplify_cal: false.";
  76. }
  77. // COST_MODEL_COMMUNI_THRESHOLD
  78. auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
  79. if (communi_threshold < 0) {
  80. MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero.";
  81. }
  82. COST_MODEL_COMMUNI_THRESHOLD = communi_threshold;
  83. MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << ".";
  84. // COST_MODEL_COMMUNI_CONST
  85. auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const();
  86. if (communi_const < 0) {
  87. MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero.";
  88. }
  89. COST_MODEL_COMMUNI_CONST = communi_const;
  90. MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << ".";
  91. // COST_MODEL_COMMUNI_BIAS
  92. auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
  93. if (communi_bias < 0) {
  94. MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero.";
  95. }
  96. COST_MODEL_COMMUNI_BIAS = communi_bias;
  97. MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << ".";
  98. // TENSOR_SLICE_ALIGNMENT_ENABLE
  99. auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
  100. TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable;
  101. if (TENSOR_SLICE_ALIGNMENT_ENABLE) {
  102. MS_LOG(INFO) << "tensor_slice_align_enable: true.";
  103. } else {
  104. MS_LOG(INFO) << "tensor_slice_align_enable: false.";
  105. }
  106. // TENSOR_SLICE_ALIGNMENT_SIZE
  107. auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
  108. if (align_size == 0) {
  109. MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
  110. }
  111. TENSOR_SLICE_ALIGNMENT_SIZE = align_size;
  112. MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << ".";
  113. // FULLY_USE_DEVICES
  114. auto fully_devices = CostModelContext::GetInstance()->fully_use_device();
  115. FULLY_USE_DEVICES = fully_devices;
  116. if (FULLY_USE_DEVICES) {
  117. MS_LOG(INFO) << "fully_use_devices: true.";
  118. } else {
  119. MS_LOG(INFO) << "fully_use_devices: false.";
  120. }
  121. // ELEMENTWISE_OP_STRA_FOLLOW
  122. auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
  123. ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow;
  124. if (ELEMENTWISE_OP_STRA_FOLLOW) {
  125. MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
  126. } else {
  127. MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
  128. }
  129. }
  130. void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
  131. for (auto it = ops_.begin(); it != ops_.end();) {
  132. if ((*it) == op) {
  133. it = ops_.erase(it);
  134. } else {
  135. ++it;
  136. }
  137. }
  138. }
  139. bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
  140. struct IsInGraph {
  141. const OperatorInfoPtr test_;
  142. explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {}
  143. bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); }
  144. };
  145. return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
  146. }
  147. bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
  148. for (auto &edge_pair : edges_) {
  149. auto edges = edge_pair.second;
  150. for (auto &edge : edges) {
  151. MS_EXCEPTION_IF_NULL(edge);
  152. bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) &&
  153. (edge->next_op_input_index() == input_index);
  154. if (bool_result) {
  155. return true;
  156. }
  157. }
  158. }
  159. return false;
  160. }
  161. std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
  162. std::vector<OperatorInfoPtr> alive_ops) {
  163. std::map<OperatorInfoPtr, bool> visited;
  164. for (auto &op : alive_ops) {
  165. visited[op] = false;
  166. }
  167. MS_LOG(INFO) << "visited: " << visited.size() << ".";
  168. for (auto &op : alive_ops) {
  169. if ((!visited[op]) && op->is_alive()) {
  170. std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
  171. MS_EXCEPTION_IF_NULL(new_component);
  172. new_component->SetDeviceMemoryAndCostParameter();
  173. DFS(op, &visited, new_component);
  174. connected_compoents_.push_back(new_component);
  175. }
  176. }
  177. return connected_compoents_;
  178. }
  179. void CostGraph::DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
  180. const std::shared_ptr<CostGraph> &component) {
  181. MS_EXCEPTION_IF_NULL(visited);
  182. MS_EXCEPTION_IF_NULL(component);
  183. visited->at(current_op) = true;
  184. component->AddOperator(current_op);
  185. for (auto &edge : current_op->succ_edges()) {
  186. bool bool_test = (visited->find(edge->next_operator()) != visited->end()) &&
  187. (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive();
  188. if (bool_test) {
  189. component->AddEdge(current_op, edge->next_operator(), edge);
  190. DFS(edge->next_operator(), visited, component);
  191. }
  192. }
  193. for (auto &edge : current_op->prev_edges()) {
  194. bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) &&
  195. (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive();
  196. if (bool_test) {
  197. component->AddEdge(edge->prev_operator(), current_op, edge);
  198. DFS(edge->prev_operator(), visited, component);
  199. }
  200. }
  201. }
  202. // Create final cost list for the graph: u --> v
  203. CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr<Edge> &e,
  204. const OperatorInfoPtr &v) {
  205. MS_EXCEPTION_IF_NULL(u);
  206. MS_EXCEPTION_IF_NULL(v);
  207. MS_EXCEPTION_IF_NULL(e);
  208. CostPtrList ret;
  209. for (const auto &u_strategy : u->GetStrategyCost()) {
  210. for (const auto &v_strategy : v->GetStrategyCost()) {
  211. MS_EXCEPTION_IF_NULL(u_strategy);
  212. MS_EXCEPTION_IF_NULL(v_strategy);
  213. auto u_strategy_ptr = u_strategy->strategy_ptr;
  214. auto v_strategy_ptr = v_strategy->strategy_ptr;
  215. CostPtrList clist1 = u_strategy->cost_list;
  216. CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr);
  217. CostPtrList clist3 = v_strategy->cost_list;
  218. for (const auto &cost1 : clist1) {
  219. for (const auto &cost2 : clist2) {
  220. for (const auto &cost3 : clist3) {
  221. MS_EXCEPTION_IF_NULL(cost1);
  222. MS_EXCEPTION_IF_NULL(cost2);
  223. MS_EXCEPTION_IF_NULL(cost3);
  224. double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
  225. double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
  226. double commmunication =
  227. cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
  228. double communication_without_para = cost1->communication_without_parameter_ +
  229. cost2->communication_without_parameter_ +
  230. cost3->communication_without_parameter_;
  231. auto decision =
  232. std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
  233. auto cost = std::make_shared<Cost>(computation, commmunication, decision);
  234. MS_EXCEPTION_IF_NULL(cost);
  235. cost->communication_without_parameter_ = communication_without_para;
  236. cost->communication_with_partial_para_ =
  237. communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para);
  238. cost->memory_with_reuse_ = memory;
  239. ret.push_back(cost);
  240. }
  241. }
  242. }
  243. }
  244. }
  245. SimplifyForDreasingCommunicationWithPartialPara(&ret);
  246. return ret;
  247. }
  248. // Create final cost list for the graph containing a signle node: u
  249. CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
  250. MS_EXCEPTION_IF_NULL(u);
  251. CostPtrList ret;
  252. for (const auto &u_strategy : u->GetStrategyCost()) {
  253. MS_EXCEPTION_IF_NULL(u_strategy);
  254. auto u_strategy_ptr = u_strategy->strategy_ptr;
  255. CostPtrList clist1 = u_strategy->cost_list;
  256. for (const auto &cost1 : clist1) {
  257. MS_EXCEPTION_IF_NULL(cost1);
  258. auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
  259. auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
  260. MS_EXCEPTION_IF_NULL(new_cost);
  261. new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
  262. new_cost->communication_with_partial_para_ =
  263. cost1->communication_without_parameter_ +
  264. COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
  265. new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
  266. ret.push_back(new_cost);
  267. }
  268. }
  269. SimplifyForDreasingCommunicationWithPartialPara(&ret);
  270. return ret;
  271. }
  272. CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) {
  273. CostPtrList after_mem_filter;
  274. // Filter out the valid costs
  275. for (auto &a_cost : cost_list) {
  276. if (a_cost->memory_with_reuse_ <= memory) {
  277. after_mem_filter.emplace_back(std::move(a_cost));
  278. }
  279. }
  280. std::function<CostPtr(CostPtr, const CostPtr &)> LocalCompare = [&](CostPtr init, const CostPtr &cost_x) {
  281. MS_EXCEPTION_IF_NULL(cost_x);
  282. if (init == nullptr || cost_x->computation_cost_ < memory) {
  283. init = cost_x;
  284. }
  285. return init;
  286. };
  287. CostPtr ret = nullptr;
  288. return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
  289. }
  290. CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
  291. // Select the cost with minimum training time. Currently, the training time is modeled as =
  292. // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_
  293. if (cost_list.empty()) {
  294. MS_LOG(ERROR) << "Final cost list is null.";
  295. return nullptr;
  296. }
  297. CostPtrList after_mem_filter;
  298. double minimum_memory = DBL_MAX;
  299. // Filter out the valid costs.
  300. for (auto &a_cost : cost_list) {
  301. if (a_cost->memory_with_reuse_ <= memory) {
  302. after_mem_filter.emplace_back(std::move(a_cost));
  303. } else if (a_cost->memory_with_reuse_ < minimum_memory) {
  304. minimum_memory = a_cost->memory_with_reuse_;
  305. }
  306. }
  307. if (after_mem_filter.empty()) {
  308. MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
  309. << ", the memory capacity is: " << memory << ".";
  310. return nullptr;
  311. }
  312. // Init the returned value with first cost.
  313. CostPtr ret = after_mem_filter[0];
  314. double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
  315. MS_LOG(INFO) << "Cost 0: "
  316. << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
  317. << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
  318. << ", communication_cost_: " << ret->communication_cost_
  319. << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
  320. MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
  321. for (size_t i = 1; i < after_mem_filter.size(); ++i) {
  322. MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
  323. MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
  324. << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
  325. << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
  326. << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
  327. << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
  328. << ".";
  329. auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
  330. costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
  331. MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
  332. if (minimum > tmp) {
  333. minimum = tmp;
  334. ret = after_mem_filter[i];
  335. MS_LOG(INFO) << "Selected: " << i;
  336. }
  337. }
  338. return ret;
  339. }
  340. CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_cost_list,
  341. double available_memory) {
  342. CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
  343. double minimum = DBL_MAX, total_memory = 0.0;
  344. CostPtrList ret(all_cost_list.size(), nullptr);
  345. // Check whether valid costs exist.
  346. for (size_t i = 0; i < all_cost_list.size(); ++i) {
  347. if (all_cost_list[i][0] == nullptr) {
  348. MS_LOG(ERROR) << "The cost list " << i << " is empty.";
  349. return ret;
  350. } else {
  351. double memory_i_cost = DBL_MAX;
  352. for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
  353. if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
  354. memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
  355. }
  356. }
  357. total_memory += memory_i_cost;
  358. }
  359. }
  360. if (total_memory >= available_memory) {
  361. MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory
  362. << ", minimum strategy cost: " << total_memory << ".";
  363. return selected_cost_list;
  364. }
  365. std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
  366. &available_memory, this](size_t k) {
  367. if (k == all_cost_list.size()) {
  368. double tmp_memory = 0.0, tmp_minimum = 0.0;
  369. for (size_t i = 0; i < selected_cost_list.size(); ++i) {
  370. MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
  371. tmp_memory += selected_cost_list[i]->memory_with_reuse_;
  372. tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
  373. costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
  374. }
  375. MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
  376. << ".";
  377. if (tmp_memory < available_memory && tmp_minimum < minimum) {
  378. ret = selected_cost_list;
  379. minimum = tmp_minimum;
  380. MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ".";
  381. }
  382. return;
  383. }
  384. MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << ".";
  385. for (auto &c : all_cost_list[k]) {
  386. selected_cost_list[k] = c;
  387. recursive(k + 1);
  388. }
  389. };
  390. recursive(0);
  391. return ret;
  392. }
  393. Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
  394. MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph.";
  395. auto connected_components = ConstructConnectedComponents(alive_ops);
  396. MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
  397. std::vector<CostPtrList> all_list;
  398. for (size_t j = 0; j < connected_components.size(); ++j) {
  399. auto one_component = connected_components[j];
  400. MS_EXCEPTION_IF_NULL(one_component);
  401. if (one_component->GetOperators().size() == 1) {
  402. MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
  403. auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
  404. all_list.push_back(cost_list);
  405. } else if (one_component->GetOperators().size() == 2) {
  406. MS_LOG(INFO) << "There are 2 operators in a component in the final graph.";
  407. OperatorInfoPtr u, v;
  408. auto first_op = one_component->GetOperators()[0];
  409. auto second_op = one_component->GetOperators()[1];
  410. MS_EXCEPTION_IF_NULL(first_op);
  411. MS_EXCEPTION_IF_NULL(second_op);
  412. if (!first_op->GetAliveSuccEdges().empty() &&
  413. first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
  414. u = first_op;
  415. v = second_op;
  416. } else if (!second_op->GetAliveSuccEdges().empty() &&
  417. second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
  418. u = second_op;
  419. v = first_op;
  420. } else {
  421. MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size()
  422. << ", " << second_op->GetAliveSuccEdges().size() << ".";
  423. }
  424. MS_EXCEPTION_IF_NULL(u);
  425. auto e = u->GetAliveSuccEdges()[0];
  426. auto cost_list = one_component->CreateFinalCostList(u, e, v);
  427. all_list.push_back(cost_list);
  428. } else {
  429. MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size()
  430. << " operators in a component in the final graph.";
  431. }
  432. }
  433. //
  434. auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
  435. for (size_t k = 0; k < selected_cost_list.size(); ++k) {
  436. auto selected_cost = selected_cost_list[k];
  437. if (selected_cost == nullptr) {
  438. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  439. return FAILED;
  440. }
  441. MS_EXCEPTION_IF_NULL(connected_components[k]);
  442. if (connected_components[k]->GetOperators().size() == 1) {
  443. auto u = connected_components[k]->GetOperators()[0];
  444. auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
  445. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
  446. MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
  447. } else if (connected_components[k]->GetOperators().size() == 2) {
  448. OperatorInfoPtr u = nullptr, v = nullptr;
  449. auto first_op = connected_components[k]->GetOperators()[0];
  450. auto second_op = connected_components[k]->GetOperators()[1];
  451. MS_EXCEPTION_IF_NULL(first_op);
  452. MS_EXCEPTION_IF_NULL(second_op);
  453. if (!first_op->GetAliveSuccEdges().empty() &&
  454. first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
  455. u = first_op;
  456. v = second_op;
  457. } else if (!second_op->GetAliveSuccEdges().empty() &&
  458. second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
  459. u = second_op;
  460. v = first_op;
  461. }
  462. MS_EXCEPTION_IF_NULL(u);
  463. auto e = u->GetAliveSuccEdges()[0];
  464. MS_EXCEPTION_IF_NULL(v);
  465. MS_EXCEPTION_IF_NULL(e);
  466. MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
  467. auto decision = selected_cost->decision_ptr_->cast<FinalDecisionPtr>();
  468. MS_EXCEPTION_IF_NULL(decision);
  469. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
  470. v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
  471. e->set_selected_cost(decision->middle_cost_);
  472. MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
  473. }
  474. }
  475. return SUCCESS;
  476. }
  477. // searching the strategy for the final eliminated graph
  478. Status CostGraph::SearchStrategy() {
  479. MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
  480. std::vector<OperatorInfoPtr> alive_ops;
  481. (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) {
  482. MS_EXCEPTION_IF_NULL(op);
  483. if (op->is_alive()) {
  484. alive_ops.push_back(op);
  485. }
  486. });
  487. if (alive_ops.size() > 2) {
  488. return SearchStrategyForMultiNodeFinalGraph(alive_ops);
  489. } else if (alive_ops.size() == 1) {
  490. MS_LOG(INFO) << "There are 1 single node in the final graph.";
  491. OperatorInfoPtr u = alive_ops[0];
  492. auto cost_list = CreateFinalSingleCostList(u);
  493. auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
  494. if (cost == nullptr) {
  495. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  496. return FAILED;
  497. }
  498. MS_EXCEPTION_IF_NULL(u);
  499. MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
  500. auto decision = cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
  501. MS_EXCEPTION_IF_NULL(decision);
  502. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
  503. MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
  504. return SUCCESS;
  505. } else {
  506. // In this case, the final graph should contains exactly 2 nodes.
  507. if (alive_ops.empty()) {
  508. MS_LOG(INFO) << "0 Operator in the final graph.";
  509. return SUCCESS;
  510. }
  511. OperatorInfoPtr u, v;
  512. MS_EXCEPTION_IF_NULL(alive_ops[0]);
  513. MS_EXCEPTION_IF_NULL(alive_ops[1]);
  514. if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
  515. alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
  516. u = alive_ops[0];
  517. v = alive_ops[1];
  518. } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
  519. alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
  520. u = alive_ops[1];
  521. v = alive_ops[0];
  522. } else {
  523. if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
  524. MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
  525. << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
  526. } else {
  527. // In this case, the final graph consists of two single nodes
  528. MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
  529. std::vector<CostPtrList> all_list;
  530. auto connected_components = ConstructConnectedComponents(alive_ops);
  531. MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
  532. for (size_t i = 0; i < connected_components.size(); ++i) {
  533. MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
  534. auto one_component = connected_components[i];
  535. MS_EXCEPTION_IF_NULL(one_component);
  536. auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
  537. all_list.push_back(cost_list);
  538. }
  539. auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
  540. for (size_t k = 0; k < selected_cost_list.size(); ++k) {
  541. auto selected_cost = selected_cost_list[k];
  542. if (selected_cost == nullptr) {
  543. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  544. return FAILED;
  545. }
  546. MS_EXCEPTION_IF_NULL(connected_components[k]);
  547. auto one_operator = connected_components[k]->GetOperators()[0];
  548. MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
  549. auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
  550. MS_EXCEPTION_IF_NULL(decision);
  551. one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
  552. MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
  553. }
  554. return SUCCESS;
  555. }
  556. }
  557. MS_LOG(INFO) << "There are 2 nodes in the final graph.";
  558. // In this case, the finale graph is exactly of the form: u --> v
  559. MS_EXCEPTION_IF_NULL(u);
  560. MS_EXCEPTION_IF_NULL(v);
  561. auto e = u->GetAliveSuccEdges()[0];
  562. MS_EXCEPTION_IF_NULL(e);
  563. auto cost_list = CreateFinalCostList(u, e, v);
  564. auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
  565. if (cost == nullptr) {
  566. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  567. return FAILED;
  568. }
  569. MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
  570. auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
  571. MS_EXCEPTION_IF_NULL(decision);
  572. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
  573. v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
  574. e->set_selected_cost(decision->middle_cost_);
  575. MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
  576. return SUCCESS;
  577. }
  578. }
  579. // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
  580. // return the v and the edge u --> v
  581. OperatorInfoPtr CostGraph::CheckOpElimination() const {
  582. for (auto &op : ops_) {
  583. bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1;
  584. if (bool_test) {
  585. if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) {
  586. return op;
  587. }
  588. }
  589. }
  590. return nullptr;
  591. }
  592. // Check the graph whether an EdgeElimination can be performed
  593. std::vector<std::shared_ptr<Edge>> CostGraph::CheckEdgeElimination() const {
  594. for (auto &op : ops_) {
  595. MS_EXCEPTION_IF_NULL(op);
  596. if (!op->is_alive()) continue;
  597. std::map<void *, int> count;
  598. for (auto &edge : op->GetAliveSuccEdges()) {
  599. MS_EXCEPTION_IF_NULL(edge);
  600. auto v = edge->next_operator();
  601. count[v.get()]++;
  602. }
  603. for (auto &pair : count) {
  604. auto *op_ptr = pair.first;
  605. int op_count = pair.second;
  606. if (op_count > 1) {
  607. std::vector<std::shared_ptr<Edge>> ret;
  608. for (auto &edge : op->GetAliveSuccEdges()) {
  609. MS_EXCEPTION_IF_NULL(edge);
  610. if (edge->next_operator().get() == op_ptr) {
  611. ret.push_back(edge);
  612. }
  613. }
  614. return ret;
  615. }
  616. }
  617. }
  618. return {};
  619. }
  620. // Check the graph whether a MergeElimination can be performed
  621. OperatorInfoPtr CostGraph::CheckMergeElimination() const {
  622. for (auto &op : ops_) {
  623. MS_EXCEPTION_IF_NULL(op);
  624. bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1;
  625. if (bool_test) {
  626. auto next_op = op->GetAliveSuccEdges()[0]->next_operator();
  627. MS_EXCEPTION_IF_NULL(next_op);
  628. if (!next_op->GetAlivePrevEdges().empty()) {
  629. return op;
  630. }
  631. }
  632. }
  633. return nullptr;
  634. }
  635. // Check the graph whether a ContractElimination can be performed
  636. OperatorInfoPtr CostGraph::CheckContractElimination() const {
  637. for (auto &op : ops_) {
  638. MS_EXCEPTION_IF_NULL(op);
  639. bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty();
  640. if (bool_test) {
  641. auto edge = op->GetAlivePrevEdges()[0];
  642. MS_EXCEPTION_IF_NULL(edge);
  643. auto prev_op = edge->prev_operator();
  644. MS_EXCEPTION_IF_NULL(prev_op);
  645. if (!prev_op->GetAliveSuccEdges().empty()) {
  646. return op;
  647. }
  648. }
  649. }
  650. return nullptr;
  651. }
  652. // Check the graph whether a TriangleElimination can be performed
  653. std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
  654. for (auto &op : ops_) {
  655. MS_EXCEPTION_IF_NULL(op);
  656. bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2);
  657. if (bool_test) {
  658. auto edge1 = op->GetAliveSuccEdges()[0];
  659. auto edge2 = op->GetAliveSuccEdges()[1];
  660. MS_EXCEPTION_IF_NULL(edge1);
  661. MS_EXCEPTION_IF_NULL(edge2);
  662. auto first_op = edge1->next_operator();
  663. auto second_op = edge2->next_operator();
  664. MS_EXCEPTION_IF_NULL(first_op);
  665. for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) {
  666. if (first_op_succ_edge->next_operator() == second_op) {
  667. return {op, first_op_succ_edge};
  668. }
  669. }
  670. MS_EXCEPTION_IF_NULL(second_op);
  671. for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) {
  672. if (second_op_succ_edge->next_operator() == first_op) {
  673. return {op, second_op_succ_edge};
  674. }
  675. }
  676. }
  677. }
  678. return {nullptr, nullptr};
  679. }
  680. // Check the graph whether a StarElimination can be performed.
  681. // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
  682. OperatorInfoPtr CostGraph::CheckStarElimination() const {
  683. for (auto &op : ops_) {
  684. MS_EXCEPTION_IF_NULL(op);
  685. bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1);
  686. if (bool_test) {
  687. return op;
  688. }
  689. }
  690. return nullptr;
  691. }
  692. // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace
  693. // 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge.
  694. std::shared_ptr<Edge> CostGraph::EliminationOp(const OperatorInfoPtr &op) {
  695. // in this case, the operators are organised in the form of u-->op-->v, and the goal
  696. // is to eliminate 'op'.
  697. MS_EXCEPTION_IF_NULL(op);
  698. MS_LOG(INFO) << "Now eliminating node: " << op->name() << ".";
  699. auto edge_u_op = op->GetAlivePrevEdges()[0];
  700. auto edge_op_v = op->GetAliveSuccEdges()[0];
  701. MS_EXCEPTION_IF_NULL(edge_u_op);
  702. MS_EXCEPTION_IF_NULL(edge_op_v);
  703. auto u = edge_u_op->prev_operator();
  704. auto v = edge_op_v->next_operator();
  705. std::vector<size_t> output_indexs, input_indexs;
  706. size_t output_index, input_index;
  707. MS_EXCEPTION_IF_NULL(u);
  708. MS_EXCEPTION_IF_NULL(v);
  709. std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
  710. std::shared_ptr<Edge> new_edge;
  711. if (edge_u_op->is_combined()) {
  712. output_indexs = edge_u_op->prev_op_output_indexs();
  713. } else {
  714. output_index = edge_u_op->prev_op_output_index();
  715. output_indexs.push_back(output_index);
  716. }
  717. if (edge_op_v->is_combined()) {
  718. input_indexs = edge_op_v->next_op_input_indexs();
  719. } else {
  720. input_index = edge_op_v->next_op_input_index();
  721. input_indexs.push_back(input_index);
  722. }
  723. if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) {
  724. new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_index, input_index, false);
  725. } else {
  726. new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
  727. }
  728. MS_EXCEPTION_IF_NULL(new_edge);
  729. new_edge->set_pre_op_output(edge_u_op->prev_op_output());
  730. new_edge->set_next_op_input(edge_op_v->next_op_input());
  731. new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v);
  732. u->ReplaceSuccEdge(op, new_edge);
  733. v->ReplacePreEdge(op, new_edge);
  734. op->SetNotAlive();
  735. MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded.";
  736. return new_edge;
  737. }
  738. // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges',
  739. // and sets new costlist for the new edge.
  740. std::shared_ptr<Edge> CostGraph::EliminationEdges(const std::vector<std::shared_ptr<Edge>> &edges) {
  741. MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges.";
  742. MS_EXCEPTION_IF_NULL(edges[0]);
  743. auto u = edges[0]->prev_operator();
  744. auto v = edges[0]->next_operator();
  745. MS_EXCEPTION_IF_NULL(u);
  746. MS_EXCEPTION_IF_NULL(v);
  747. std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
  748. std::vector<size_t> output_indexs, input_indexs;
  749. for (auto &edge : edges) {
  750. MS_EXCEPTION_IF_NULL(edge);
  751. if (edge->is_combined()) {
  752. auto from_output_indexs = edge->prev_op_output_indexs();
  753. auto from_input_indexs = edge->next_op_input_indexs();
  754. (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs));
  755. (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs));
  756. } else {
  757. output_indexs.push_back(edge->prev_op_output_index());
  758. input_indexs.push_back(edge->next_op_input_index());
  759. }
  760. }
  761. std::shared_ptr<Edge> new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
  762. MS_EXCEPTION_IF_NULL(new_edge);
  763. new_edge->set_pre_op_output(edges[0]->prev_op_output());
  764. new_edge->set_next_op_input(edges[0]->next_op_input());
  765. new_edge->EdgeEliminationSetNewCost(u, edges, v);
  766. u->ReplaceSuccEdges(v, new_edge);
  767. v->ReplacePreEdges(u, new_edge);
  768. MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded.";
  769. return new_edge;
  770. }
  771. // Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
  772. // for this contract under the strategy 'op_strategy'
  773. void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
  774. const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
  775. const CostPtrList &tar_cost_list,
  776. CostPtrList *const tar_cost_list_new) {
  777. for (size_t i = 0; i < op_cost_list.size(); ++i) {
  778. auto &op_cost = op_cost_list[i];
  779. MS_EXCEPTION_IF_NULL(op_cost);
  780. for (size_t j = 0; j < edge_cost_list.size(); ++j) {
  781. auto &edge_cost = edge_cost_list[j];
  782. MS_EXCEPTION_IF_NULL(edge_cost);
  783. for (size_t k = 0; k < tar_cost_list.size(); ++k) {
  784. auto &tar_cost = tar_cost_list[k];
  785. MS_EXCEPTION_IF_NULL(tar_cost);
  786. double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
  787. double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
  788. double communication =
  789. op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
  790. double communication_without_para = op_cost->communication_without_parameter_ +
  791. edge_cost->communication_without_parameter_ +
  792. tar_cost->communication_without_parameter_;
  793. auto decision =
  794. std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
  795. auto new_cost = std::make_shared<Cost>(computation, communication, decision);
  796. MS_EXCEPTION_IF_NULL(new_cost);
  797. new_cost->communication_without_parameter_ = communication_without_para;
  798. new_cost->communication_with_partial_para_ =
  799. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  800. new_cost->memory_with_reuse_ = memory;
  801. MS_EXCEPTION_IF_NULL(tar_cost_list_new);
  802. tar_cost_list_new->emplace_back(std::move(new_cost));
  803. }
  804. }
  805. }
  806. }
  807. // This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the
  808. // target_op
  809. OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {
  810. MS_EXCEPTION_IF_NULL(op);
  811. auto target_op = op->GetAliveSuccEdges()[0]->next_operator();
  812. auto edge_ptr = op->GetAliveSuccEdges()[0];
  813. MS_EXCEPTION_IF_NULL(target_op);
  814. MS_EXCEPTION_IF_NULL(edge_ptr);
  815. MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << ".";
  816. bool valid = false;
  817. for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
  818. MS_EXCEPTION_IF_NULL(tar_stra_cost);
  819. auto tar_stra = tar_stra_cost->strategy_ptr;
  820. auto tar_clist_origin = tar_stra_cost->cost_list;
  821. CostPtrList tar_clist_new;
  822. for (auto &op_stra_cost : op->GetStrategyCost()) {
  823. MS_EXCEPTION_IF_NULL(op_stra_cost);
  824. auto op_stra = op_stra_cost->strategy_ptr;
  825. auto op_clist = op_stra_cost->cost_list;
  826. auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra);
  827. CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
  828. }
  829. SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
  830. // Set the new costlist w.r.t the strategy
  831. tar_stra_cost->cost_list = tar_clist_new;
  832. if ((!valid) && (!tar_clist_new.empty())) {
  833. valid = true;
  834. }
  835. }
  836. if (!valid) {
  837. MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed.";
  838. }
  839. op->SetNotAlive();
  840. MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded.";
  841. return target_op;
  842. }
  843. // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
  844. // for this contract under the strategy 'contract_op_stra'
  845. void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,
  846. const CostPtrList &contract_op_cost_list,
  847. const CostPtrList &edge_cost_list, StrategyPtr target_op_stra,
  848. const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) {
  849. for (size_t i = 0; i < contract_op_cost_list.size(); ++i) {
  850. auto &contract_op_cost = contract_op_cost_list[i];
  851. MS_EXCEPTION_IF_NULL(contract_op_cost);
  852. for (size_t j = 0; j < edge_cost_list.size(); ++j) {
  853. auto &edge_cost = edge_cost_list[j];
  854. MS_EXCEPTION_IF_NULL(edge_cost);
  855. for (size_t k = 0; k < tar_cost_list.size(); ++k) {
  856. auto &tar_cost = tar_cost_list[k];
  857. MS_EXCEPTION_IF_NULL(tar_cost);
  858. double computation =
  859. contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
  860. double memory =
  861. contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
  862. double communication =
  863. contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
  864. double communication_without_para = contract_op_cost->communication_without_parameter_ +
  865. edge_cost->communication_without_parameter_ +
  866. tar_cost->communication_without_parameter_;
  867. auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
  868. target_op_stra, tar_cost);
  869. auto new_cost = std::make_shared<Cost>(computation, communication, decision);
  870. new_cost->communication_without_parameter_ = communication_without_para;
  871. new_cost->communication_with_partial_para_ =
  872. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  873. new_cost->memory_with_reuse_ = memory;
  874. tar_cost_list_new->emplace_back(std::move(new_cost));
  875. }
  876. }
  877. }
  878. }
  879. // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the
  880. // target_op
  881. OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {
  882. MS_EXCEPTION_IF_NULL(op);
  883. auto target_op = op->GetAlivePrevEdges()[0]->prev_operator();
  884. auto edge_ptr = op->GetAlivePrevEdges()[0];
  885. MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << ".";
  886. bool valid = false;
  887. for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
  888. MS_EXCEPTION_IF_NULL(tar_stra_cost);
  889. auto tar_stra = tar_stra_cost->strategy_ptr;
  890. auto tar_clist_origin = tar_stra_cost->cost_list;
  891. CostPtrList tar_clist_new;
  892. for (auto &op_stra_cost : op->GetStrategyCost()) {
  893. MS_EXCEPTION_IF_NULL(op_stra_cost);
  894. auto op_stra = op_stra_cost->strategy_ptr;
  895. auto op_clist = op_stra_cost->cost_list;
  896. auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra);
  897. CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
  898. }
  899. SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
  900. // Set the new costlist w.r.t the strategy
  901. tar_stra_cost->cost_list = tar_clist_new;
  902. if ((!valid) && (!tar_clist_new.empty())) {
  903. valid = true;
  904. }
  905. }
  906. if (!valid) {
  907. MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed.";
  908. }
  909. op->SetNotAlive();
  910. MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded.";
  911. return target_op;
  912. }
  913. void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
  914. StrategyPtr right_op_stra, const CostPtr &right_op_cost,
  915. const CostPtrList &elimi_op_clist,
  916. const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost,
  917. const CostPtrList &left_node_clist_origin,
  918. CostPtrList *left_node_clist_new) {
  919. MS_EXCEPTION_IF_NULL(right_edge_cost);
  920. MS_EXCEPTION_IF_NULL(right_op_cost);
  921. MS_EXCEPTION_IF_NULL(left_node_clist_new);
  922. for (auto &elimi_op_cost : elimi_op_clist) {
  923. MS_EXCEPTION_IF_NULL(elimi_op_cost);
  924. for (auto &left_edge_cost : left_edge_clist) {
  925. MS_EXCEPTION_IF_NULL(left_edge_cost);
  926. for (auto &left_node_cost : left_node_clist_origin) {
  927. MS_EXCEPTION_IF_NULL(left_node_cost);
  928. double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
  929. left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
  930. double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
  931. left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
  932. double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
  933. left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
  934. double new_commu_without =
  935. elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
  936. left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
  937. auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost,
  938. right_edge_cost, left_op_stra, left_node_cost);
  939. auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
  940. new_cost->communication_without_parameter_ = new_commu_without;
  941. new_cost->communication_with_partial_para_ =
  942. new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
  943. new_cost->memory_with_reuse_ = new_memory;
  944. left_node_clist_new->emplace_back(std::move(new_cost));
  945. }
  946. }
  947. }
  948. }
  949. void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist,
  950. const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra,
  951. const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra,
  952. const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist,
  953. const CostPtrList &left_node_clist_origin,
  954. CostPtrList *left_node_clist_new) {
  955. MS_EXCEPTION_IF_NULL(elimi_op);
  956. for (auto &right_node_cost : right_node_clist) {
  957. MS_EXCEPTION_IF_NULL(right_node_cost);
  958. for (auto &right_edge_cost : right_edge_clist) {
  959. MS_EXCEPTION_IF_NULL(right_edge_cost);
  960. CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
  961. elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
  962. left_node_clist_new);
  963. }
  964. }
  965. }
  966. OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
  967. const std::shared_ptr<Edge> &edge_left_right) {
  968. MS_EXCEPTION_IF_NULL(edge_left_right);
  969. MS_EXCEPTION_IF_NULL(elimi_op);
  970. MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << ".";
  971. auto left_node = edge_left_right->prev_operator();
  972. auto right_node = edge_left_right->next_operator();
  973. auto left_edge = elimi_op->GetAliveSuccEdges()[0];
  974. auto right_edge = elimi_op->GetAliveSuccEdges()[1];
  975. MS_EXCEPTION_IF_NULL(left_node);
  976. MS_EXCEPTION_IF_NULL(right_node);
  977. MS_EXCEPTION_IF_NULL(left_edge);
  978. MS_EXCEPTION_IF_NULL(right_edge);
  979. MS_LOG(INFO) << "The left operator is: " << left_node->name() << ".";
  980. MS_LOG(INFO) << "The right operator is: " << right_node->name() << ".";
  981. if (left_edge->next_operator() != left_node) {
  982. auto tmp = left_edge;
  983. left_edge = right_edge;
  984. right_edge = tmp;
  985. }
  986. bool valid = false;
  987. for (auto &left_node_stra_cost : left_node->GetStrategyCost()) {
  988. MS_EXCEPTION_IF_NULL(left_node_stra_cost);
  989. auto left_node_stra = left_node_stra_cost->strategy_ptr;
  990. auto left_node_clist_origin = left_node_stra_cost->cost_list;
  991. CostPtrList left_node_clist_new;
  992. for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) {
  993. MS_EXCEPTION_IF_NULL(elimi_op_stra_cost);
  994. auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr;
  995. auto elimi_op_clist = elimi_op_stra_cost->cost_list;
  996. auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra);
  997. for (auto &right_node_stra_cost : right_node->GetStrategyCost()) {
  998. MS_EXCEPTION_IF_NULL(right_node_stra_cost);
  999. auto right_node_stra = right_node_stra_cost->strategy_ptr;
  1000. auto right_node_clist = right_node_stra_cost->cost_list;
  1001. auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra);
  1002. CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra,
  1003. right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin,
  1004. &left_node_clist_new);
  1005. }
  1006. }
  1007. SimplifyForDreasingCommunicationWithPartialPara(&left_node_clist_new);
  1008. // Set the new costlist w.r.t the strategy
  1009. left_node_stra_cost->cost_list = left_node_clist_new;
  1010. if ((!valid) && (!left_node_clist_new.empty())) {
  1011. valid = true;
  1012. }
  1013. }
  1014. if (!valid) {
  1015. MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed.";
  1016. }
  1017. elimi_op->SetNotAlive();
  1018. MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded.";
  1019. return left_node;
  1020. }
  1021. void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra,
  1022. const CostPtrList &first_succ_node_clist,
  1023. const CostPtrList &first_succ_edge_clist,
  1024. const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
  1025. std::vector<StrategyPtr> succ_nodes_stras,
  1026. CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs,
  1027. CostPtrList *first_succ_node_clist_new) {
  1028. for (auto &first_succ_node_cost : first_succ_node_clist) {
  1029. for (auto &first_succ_edge_cost : first_succ_edge_clist) {
  1030. for (auto &merged_node_cost : merged_op_clist) {
  1031. MS_EXCEPTION_IF_NULL(merged_node_cost);
  1032. succ_nodes_stras[0] = first_succ_node_stra;
  1033. succ_edges_costs[0] = first_succ_edge_cost;
  1034. succ_nodes_costs[0] = first_succ_node_cost;
  1035. double computation_cost = merged_node_cost->computation_cost_,
  1036. memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
  1037. commu_without = merged_node_cost->communication_without_parameter_;
  1038. for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
  1039. MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
  1040. if (i == 0) {
  1041. computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
  1042. memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
  1043. commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
  1044. commu_without += succ_edges_costs[i]->communication_without_parameter_ +
  1045. succ_nodes_costs[i]->communication_without_parameter_;
  1046. } else {
  1047. computation_cost += succ_edges_costs[i]->computation_cost_;
  1048. memory_cost += succ_edges_costs[i]->memory_with_reuse_;
  1049. commu_cost += succ_edges_costs[i]->communication_cost_;
  1050. commu_without += succ_edges_costs[i]->communication_without_parameter_;
  1051. }
  1052. }
  1053. auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
  1054. succ_nodes_stras, succ_nodes_costs);
  1055. auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
  1056. new_cost->communication_without_parameter_ = commu_without;
  1057. new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
  1058. new_cost->memory_with_reuse_ = memory_cost;
  1059. first_succ_node_clist_new->emplace_back(std::move(new_cost));
  1060. }
  1061. }
  1062. }
  1063. }
  1064. void CostGraph::CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> &succ_edges,
  1065. const StrategyPtr &first_succ_node_stra,
  1066. const CostPtrList &first_succ_node_clist,
  1067. const CostPtrList &first_succ_edge_clist,
  1068. const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
  1069. CostPtrList *first_succ_node_clist_new) {
  1070. std::vector<StrategyPtr> succ_nodes_stras(succ_edges.size(), nullptr);
  1071. CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr);
  1072. std::function<void(size_t)> recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist,
  1073. &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs,
  1074. &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive,
  1075. this](size_t k) {
  1076. if (k == succ_edges.size()) {
  1077. CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
  1078. merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs,
  1079. succ_nodes_costs, first_succ_node_clist_new);
  1080. return;
  1081. }
  1082. MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size()
  1083. << ", first_succ_edge_clist: " << first_succ_edge_clist.size()
  1084. << ", merged_op_clist: " << merged_op_clist.size()
  1085. << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << ".";
  1086. auto succ_edge = succ_edges[k];
  1087. MS_EXCEPTION_IF_NULL(succ_edge);
  1088. auto succ_node = succ_edge->next_operator();
  1089. MS_EXCEPTION_IF_NULL(succ_node);
  1090. for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) {
  1091. MS_EXCEPTION_IF_NULL(succ_node_stra_cost);
  1092. auto succ_node_stra = succ_node_stra_cost->strategy_ptr;
  1093. auto succ_node_clist = succ_node_stra_cost->cost_list;
  1094. auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra);
  1095. for (auto &succ_node_cost : succ_node_clist) {
  1096. MS_EXCEPTION_IF_NULL(succ_node_cost);
  1097. for (auto &succ_edge_cost : succ_edge_clist) {
  1098. MS_EXCEPTION_IF_NULL(succ_edge_cost);
  1099. succ_nodes_stras[k] = succ_node_stra;
  1100. succ_edges_costs[k] = succ_edge_cost;
  1101. succ_nodes_costs[k] = succ_node_cost;
  1102. recursive(k + 1);
  1103. }
  1104. }
  1105. }
  1106. };
  1107. recursive(1);
  1108. }
  1109. std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) {
  1110. MS_EXCEPTION_IF_NULL(merged_op);
  1111. auto succ_edges = merged_op->GetAliveSuccEdges();
  1112. MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << ".";
  1113. for (auto &succ_edge : succ_edges) {
  1114. MS_EXCEPTION_IF_NULL(succ_edge->next_operator());
  1115. MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << ".";
  1116. }
  1117. MS_EXCEPTION_IF_NULL(succ_edges[0]);
  1118. auto first_succ_node = succ_edges[0]->next_operator();
  1119. auto first_succ_edge = succ_edges[0];
  1120. bool valid = false;
  1121. // 'merged_op' is merged into first_node
  1122. MS_EXCEPTION_IF_NULL(first_succ_node);
  1123. for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) {
  1124. MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost);
  1125. auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr;
  1126. auto first_succ_node_clist = first_succ_node_stra_cost->cost_list;
  1127. CostPtrList first_succ_node_clist_new;
  1128. for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) {
  1129. MS_EXCEPTION_IF_NULL(merged_op_stra_cost);
  1130. auto merged_op_stra = merged_op_stra_cost->strategy_ptr;
  1131. auto merged_op_clist = merged_op_stra_cost->cost_list;
  1132. auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra);
  1133. CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
  1134. merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
  1135. }
  1136. SimplifyForDreasingCommunicationWithPartialPara(&first_succ_node_clist_new);
  1137. // Set the new costlist w.r.t the strategy
  1138. first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
  1139. if ((!valid) && (!first_succ_node_clist_new.empty())) {
  1140. valid = true;
  1141. }
  1142. }
  1143. if (!valid) {
  1144. MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed.";
  1145. }
  1146. merged_op->SetNotAlive();
  1147. MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded.";
  1148. return succ_edges;
  1149. }
  1150. Status CostGraph::InitSelectedStrategy() {
  1151. for (auto &op : ops_) {
  1152. MS_EXCEPTION_IF_NULL(op);
  1153. auto result = op->InitSelectedStrategy(op->selected_strategy());
  1154. if (result != SUCCESS) {
  1155. return result;
  1156. }
  1157. }
  1158. return SUCCESS;
  1159. }
  1160. Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
  1161. for (auto &op : ops_) {
  1162. MS_EXCEPTION_IF_NULL(op);
  1163. const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved();
  1164. if ((output_parameter != 0) && (output_parameter != 1)) {
  1165. MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed.";
  1166. return FAILED;
  1167. }
  1168. }
  1169. return SUCCESS;
  1170. }
  1171. Status CostGraph::CalculateOpsMemoryCost() {
  1172. for (auto &op : ops_) {
  1173. MS_EXCEPTION_IF_NULL(op);
  1174. if (op->CalculateMemoryCost() != SUCCESS) {
  1175. MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
  1176. return FAILED;
  1177. }
  1178. }
  1179. return SUCCESS;
  1180. }
  1181. Status CostGraph::CalculateEdgesMemoryCost() {
  1182. for (auto &edge_pair : edges_) {
  1183. const auto &edges = edge_pair.second;
  1184. for (auto &one_edge : edges) {
  1185. if (one_edge->CalculateMemoryCost() != SUCCESS) {
  1186. MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
  1187. return FAILED;
  1188. }
  1189. }
  1190. }
  1191. return SUCCESS;
  1192. }
  1193. OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
  1194. for (auto one_op : ops_) {
  1195. if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
  1196. if (one_op->refkey_parameter_name() == p_name) {
  1197. return one_op;
  1198. }
  1199. }
  1200. }
  1201. return nullptr;
  1202. }
  1203. Status CostGraph::CorrectOpsMemoryCost() {
  1204. for (auto &one_op : ops_) {
  1205. if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
  1206. if (one_op->GetAliveSuccEdges().size() > 1) {
  1207. // Filter out the case when the TmpIdentity being used by multiple operators
  1208. std::map<size_t, int> output_count;
  1209. for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
  1210. auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
  1211. output_count[output_index]++;
  1212. }
  1213. for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
  1214. auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
  1215. if (output_count[output_index] <= 1) {
  1216. continue;
  1217. }
  1218. auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
  1219. MS_EXCEPTION_IF_NULL(next_op);
  1220. auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
  1221. if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
  1222. MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
  1223. << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
  1224. return FAILED;
  1225. }
  1226. output_count[output_index]--;
  1227. }
  1228. }
  1229. }
  1230. }
  1231. return SUCCESS;
  1232. }
  1233. } // namespace parallel
  1234. } // namespace mindspore