You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_costmodel.cc 75 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <algorithm>
  17. #include <cstdlib>
  18. #include <iterator>
  19. #include <numeric>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "parallel/auto_parallel/graph_costmodel.h"
  24. #include "parallel/ops_info/reshape_info.h"
  25. #include "parallel/step_auto_parallel.h"
  26. namespace mindspore {
  27. namespace parallel {
  28. CostGraphPtr entire_costgraph = nullptr;
  29. size_t TOTAL_OPS = 0;
  30. double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA;
  31. bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
  32. double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY;
  33. double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
  34. double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST;
  35. double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS;
  36. bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
  37. size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
  38. bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
  39. bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
  40. bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
  41. int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
  42. constexpr char RESHAPEINFO[] = "ReshapeInfo";
  43. void CostGraph::SetDeviceMemoryAndCostParameter() {
  44. MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
  45. // DEVICE_MEMORY_CAPACITY
  46. auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
  47. if (device_memory <= 0) {
  48. MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
  49. }
  50. dev_memory_ = device_memory;
  51. DEVICE_MEMORY_CAPACITY = device_memory;
  52. MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << ".";
  53. // COST_MODEL_ALPHA
  54. auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
  55. if (alpha <= 0) {
  56. MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
  57. }
  58. costmodel_alpha_ = alpha;
  59. MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
  60. // COST_MODEL_BETA
  61. auto beta = CostModelContext::GetInstance()->costmodel_beta();
  62. if (beta <= 0) {
  63. MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
  64. }
  65. costmodel_beta_ = beta;
  66. MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
  67. // COST_MODEL_GAMMA
  68. auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  69. if ((gamma < 0) || (gamma > 1)) {
  70. MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1].";
  71. }
  72. COST_MODEL_GAMMA = gamma;
  73. MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << ".";
  74. // COST_MODEL_SIMPLIFY_CALCULATION
  75. auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal();
  76. COST_MODEL_SIMPLIFY_CALCULATION = simplify;
  77. if (COST_MODEL_SIMPLIFY_CALCULATION) {
  78. MS_LOG(INFO) << "costmodel_simplify_cal: true.";
  79. } else {
  80. MS_LOG(INFO) << "costmodel_simplify_cal: false.";
  81. }
  82. // COST_MODEL_COMMUNI_THRESHOLD
  83. auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
  84. if (communi_threshold < 0) {
  85. MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero.";
  86. }
  87. COST_MODEL_COMMUNI_THRESHOLD = communi_threshold;
  88. MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << ".";
  89. // COST_MODEL_COMMUNI_CONST
  90. auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const();
  91. if (communi_const < 0) {
  92. MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero.";
  93. }
  94. COST_MODEL_COMMUNI_CONST = communi_const;
  95. MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << ".";
  96. // COST_MODEL_COMMUNI_BIAS
  97. auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
  98. if (communi_bias < 0) {
  99. MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero.";
  100. }
  101. COST_MODEL_COMMUNI_BIAS = communi_bias;
  102. MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << ".";
  103. // TENSOR_SLICE_ALIGNMENT_ENABLE
  104. auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
  105. TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable;
  106. if (TENSOR_SLICE_ALIGNMENT_ENABLE) {
  107. MS_LOG(INFO) << "tensor_slice_align_enable: true.";
  108. } else {
  109. MS_LOG(INFO) << "tensor_slice_align_enable: false.";
  110. }
  111. // TENSOR_SLICE_ALIGNMENT_SIZE
  112. auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
  113. if (align_size == 0) {
  114. MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
  115. }
  116. TENSOR_SLICE_ALIGNMENT_SIZE = align_size;
  117. MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << ".";
  118. // FULLY_USE_DEVICES
  119. auto fully_devices = CostModelContext::GetInstance()->fully_use_device();
  120. FULLY_USE_DEVICES = fully_devices;
  121. if (FULLY_USE_DEVICES) {
  122. MS_LOG(INFO) << "fully_use_devices: true.";
  123. } else {
  124. MS_LOG(INFO) << "fully_use_devices: false.";
  125. }
  126. // ELEMENTWISE_OP_STRA_FOLLOW
  127. auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
  128. ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow;
  129. if (ELEMENTWISE_OP_STRA_FOLLOW) {
  130. MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
  131. } else {
  132. MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
  133. }
  134. // MULTI_SUBGRAPHS
  135. auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
  136. MULTI_SUBGRAPHS = multi_subgraphs;
  137. if (MULTI_SUBGRAPHS) {
  138. MS_LOG(INFO) << "multi_subgraphs: true.";
  139. } else {
  140. MS_LOG(INFO) << "multi_subgraphs: false.";
  141. }
  142. // RUN_PHASE
  143. auto phase = CostModelContext::GetInstance()->run_phase();
  144. if (phase != 0 && phase != 1) {
  145. MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
  146. }
  147. RUN_PHASE = phase;
  148. MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
  149. }
  150. void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
  151. for (auto it = ops_.begin(); it != ops_.end();) {
  152. if ((*it) == op) {
  153. it = ops_.erase(it);
  154. } else {
  155. ++it;
  156. }
  157. }
  158. }
  159. bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
  160. struct IsInGraph {
  161. const OperatorInfoPtr test_;
  162. explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {}
  163. bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); }
  164. };
  165. return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
  166. }
  167. void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
  168. std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
  169. curr_edges.push_back(edge);
  170. edges_[{u_node, v_node}] = curr_edges;
  171. std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
  172. curr_out_edges.push_back(edge);
  173. out_edges_[u_node] = curr_out_edges;
  174. std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
  175. curr_in_edges.push_back(edge);
  176. in_edges_[v_node] = curr_in_edges;
  177. }
  178. bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
  179. for (auto &edge_pair : edges_) {
  180. auto edges = edge_pair.second;
  181. for (auto &edge : edges) {
  182. MS_EXCEPTION_IF_NULL(edge);
  183. bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) &&
  184. (edge->next_op_input_index() == input_index);
  185. if (bool_result) {
  186. return true;
  187. }
  188. }
  189. }
  190. return false;
  191. }
  192. std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
  193. std::vector<OperatorInfoPtr> alive_ops) {
  194. std::map<OperatorInfoPtr, bool> visited;
  195. for (auto &op : alive_ops) {
  196. visited[op] = false;
  197. }
  198. MS_LOG(INFO) << "visited: " << visited.size() << ".";
  199. for (auto &op : alive_ops) {
  200. if ((!visited[op]) && op->is_alive()) {
  201. std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
  202. MS_EXCEPTION_IF_NULL(new_component);
  203. new_component->SetDeviceMemoryAndCostParameter();
  204. DFS(op, &visited, new_component);
  205. connected_compoents_.push_back(new_component);
  206. }
  207. }
  208. return connected_compoents_;
  209. }
  210. void CostGraph::DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
  211. const std::shared_ptr<CostGraph> &component) {
  212. MS_EXCEPTION_IF_NULL(visited);
  213. MS_EXCEPTION_IF_NULL(component);
  214. visited->at(current_op) = true;
  215. component->AddOperator(current_op);
  216. for (auto &edge : current_op->succ_edges()) {
  217. bool bool_test = (visited->find(edge->next_operator()) != visited->end()) &&
  218. (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive();
  219. if (bool_test) {
  220. component->AddEdge(current_op, edge->next_operator(), edge);
  221. DFS(edge->next_operator(), visited, component);
  222. }
  223. }
  224. for (auto &edge : current_op->prev_edges()) {
  225. bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) &&
  226. (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive();
  227. if (bool_test) {
  228. component->AddEdge(edge->prev_operator(), current_op, edge);
  229. DFS(edge->prev_operator(), visited, component);
  230. }
  231. }
  232. }
  233. // Create final cost list for the graph: u --> v
  234. CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr<Edge> &e,
  235. const OperatorInfoPtr &v) {
  236. MS_EXCEPTION_IF_NULL(u);
  237. MS_EXCEPTION_IF_NULL(v);
  238. MS_EXCEPTION_IF_NULL(e);
  239. CostPtrList ret;
  240. for (const auto &u_strategy : u->GetStrategyCost()) {
  241. for (const auto &v_strategy : v->GetStrategyCost()) {
  242. MS_EXCEPTION_IF_NULL(u_strategy);
  243. MS_EXCEPTION_IF_NULL(v_strategy);
  244. auto u_strategy_ptr = u_strategy->strategy_ptr;
  245. auto v_strategy_ptr = v_strategy->strategy_ptr;
  246. CostPtrList clist1 = u_strategy->cost_list;
  247. CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr);
  248. CostPtrList clist3 = v_strategy->cost_list;
  249. for (const auto &cost1 : clist1) {
  250. for (const auto &cost2 : clist2) {
  251. for (const auto &cost3 : clist3) {
  252. MS_EXCEPTION_IF_NULL(cost1);
  253. MS_EXCEPTION_IF_NULL(cost2);
  254. MS_EXCEPTION_IF_NULL(cost3);
  255. double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
  256. double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
  257. double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
  258. double communication_forward =
  259. cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
  260. double communication_without_para = cost1->communication_without_parameter_ +
  261. cost2->communication_without_parameter_ +
  262. cost3->communication_without_parameter_;
  263. auto decision =
  264. std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
  265. auto cost = std::make_shared<Cost>(computation, communication, decision);
  266. MS_EXCEPTION_IF_NULL(cost);
  267. cost->communication_without_parameter_ = communication_without_para;
  268. cost->communication_with_partial_para_ =
  269. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  270. cost->memory_with_reuse_ = memory;
  271. cost->communication_forward_ = communication_forward;
  272. ret.push_back(cost);
  273. }
  274. }
  275. }
  276. }
  277. }
  278. Simplify(&ret);
  279. return ret;
  280. }
  281. // Create final cost list for the graph containing a signle node: u
  282. CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
  283. MS_EXCEPTION_IF_NULL(u);
  284. CostPtrList ret;
  285. for (const auto &u_strategy : u->GetStrategyCost()) {
  286. MS_EXCEPTION_IF_NULL(u_strategy);
  287. auto u_strategy_ptr = u_strategy->strategy_ptr;
  288. CostPtrList clist1 = u_strategy->cost_list;
  289. for (const auto &cost1 : clist1) {
  290. MS_EXCEPTION_IF_NULL(cost1);
  291. auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
  292. auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
  293. MS_EXCEPTION_IF_NULL(new_cost);
  294. new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
  295. new_cost->communication_with_partial_para_ =
  296. cost1->communication_without_parameter_ +
  297. COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
  298. new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
  299. new_cost->communication_forward_ = cost1->communication_forward_;
  300. ret.push_back(new_cost);
  301. }
  302. }
  303. Simplify(&ret);
  304. return ret;
  305. }
  306. CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
  307. // Select the cost with minimum inference time. Currently, the inference time is modeled as =
  308. // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
  309. if (cost_list.empty()) {
  310. MS_LOG(ERROR) << "Final cost list is null.";
  311. return nullptr;
  312. }
  313. CostPtrList after_mem_filter;
  314. double minimum_memory = DBL_MAX;
  315. // Filter out the valid costs.
  316. for (auto &a_cost : cost_list) {
  317. if (a_cost->memory_with_reuse_ <= memory) {
  318. after_mem_filter.emplace_back(std::move(a_cost));
  319. } else if (a_cost->memory_with_reuse_ < minimum_memory) {
  320. minimum_memory = a_cost->memory_with_reuse_;
  321. }
  322. }
  323. if (after_mem_filter.empty()) {
  324. MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
  325. << ", the memory capacity is: " << memory << ".";
  326. return nullptr;
  327. }
  328. // Init the returned value with first cost.
  329. CostPtr ret = after_mem_filter[0];
  330. double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
  331. MS_LOG(INFO) << "Cost 0: "
  332. << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
  333. << ", communication_forward_: " << ret->communication_forward_
  334. << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
  335. << ", communication_cost_: " << ret->communication_cost_
  336. << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
  337. MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
  338. for (size_t i = 1; i < after_mem_filter.size(); ++i) {
  339. MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
  340. MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
  341. << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
  342. << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
  343. << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
  344. << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
  345. << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
  346. << ".";
  347. auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
  348. costmodel_beta_ * after_mem_filter[i]->communication_forward_;
  349. MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
  350. if (minimum > tmp) {
  351. minimum = tmp;
  352. ret = after_mem_filter[i];
  353. MS_LOG(INFO) << "Selected: " << i;
  354. }
  355. }
  356. return ret;
  357. }
  358. CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
  359. // Select the cost with minimum training time. Currently, the training time is modeled as =
  360. // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_
  361. if (cost_list.empty()) {
  362. MS_LOG(ERROR) << "Final cost list is null.";
  363. return nullptr;
  364. }
  365. CostPtrList after_mem_filter;
  366. double minimum_memory = DBL_MAX;
  367. // Filter out the valid costs.
  368. for (auto &a_cost : cost_list) {
  369. if (a_cost->memory_with_reuse_ <= memory) {
  370. after_mem_filter.emplace_back(std::move(a_cost));
  371. } else if (a_cost->memory_with_reuse_ < minimum_memory) {
  372. minimum_memory = a_cost->memory_with_reuse_;
  373. }
  374. }
  375. if (after_mem_filter.empty()) {
  376. MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
  377. << ", the memory capacity is: " << memory << ".";
  378. return nullptr;
  379. }
  380. // Init the returned value with first cost.
  381. CostPtr ret = after_mem_filter[0];
  382. double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
  383. MS_LOG(INFO) << "Cost 0: "
  384. << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
  385. << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
  386. << ", communication_cost_: " << ret->communication_cost_
  387. << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
  388. MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
  389. for (size_t i = 1; i < after_mem_filter.size(); ++i) {
  390. MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
  391. MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
  392. << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
  393. << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
  394. << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
  395. << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
  396. << ".";
  397. auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
  398. costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
  399. MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
  400. if (minimum > tmp) {
  401. minimum = tmp;
  402. ret = after_mem_filter[i];
  403. MS_LOG(INFO) << "Selected: " << i;
  404. }
  405. }
  406. return ret;
  407. }
  408. CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_cost_list,
  409. double available_memory) {
  410. CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
  411. double minimum = DBL_MAX, total_memory = 0.0;
  412. CostPtrList ret(all_cost_list.size(), nullptr);
  413. // Check whether valid costs exist.
  414. for (size_t i = 0; i < all_cost_list.size(); ++i) {
  415. if (all_cost_list[i][0] == nullptr) {
  416. MS_LOG(ERROR) << "The cost list " << i << " is empty.";
  417. return ret;
  418. } else {
  419. double memory_i_cost = DBL_MAX;
  420. for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
  421. if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
  422. memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
  423. }
  424. }
  425. total_memory += memory_i_cost;
  426. }
  427. }
  428. if (total_memory >= available_memory) {
  429. MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory
  430. << ", minimum strategy cost: " << total_memory << ".";
  431. return selected_cost_list;
  432. }
  433. std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
  434. &available_memory, this](size_t k) {
  435. if (k == all_cost_list.size()) {
  436. double tmp_memory = 0.0, tmp_minimum = 0.0;
  437. for (size_t i = 0; i < selected_cost_list.size(); ++i) {
  438. MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
  439. tmp_memory += selected_cost_list[i]->memory_with_reuse_;
  440. tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
  441. costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
  442. }
  443. MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
  444. << ".";
  445. if (tmp_memory < available_memory && tmp_minimum < minimum) {
  446. ret = selected_cost_list;
  447. minimum = tmp_minimum;
  448. MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ".";
  449. }
  450. return;
  451. }
  452. MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << ".";
  453. for (auto &c : all_cost_list[k]) {
  454. selected_cost_list[k] = c;
  455. recursive(k + 1);
  456. }
  457. };
  458. recursive(0);
  459. return ret;
  460. }
  461. Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
  462. MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph.";
  463. auto connected_components = ConstructConnectedComponents(alive_ops);
  464. MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
  465. std::vector<CostPtrList> all_list;
  466. for (size_t j = 0; j < connected_components.size(); ++j) {
  467. auto one_component = connected_components[j];
  468. MS_EXCEPTION_IF_NULL(one_component);
  469. if (one_component->GetOperators().size() == 1) {
  470. MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
  471. auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
  472. all_list.push_back(cost_list);
  473. } else if (one_component->GetOperators().size() == 2) {
  474. MS_LOG(INFO) << "There are 2 operators in a component in the final graph.";
  475. OperatorInfoPtr u, v;
  476. auto first_op = one_component->GetOperators()[0];
  477. auto second_op = one_component->GetOperators()[1];
  478. MS_EXCEPTION_IF_NULL(first_op);
  479. MS_EXCEPTION_IF_NULL(second_op);
  480. if (!first_op->GetAliveSuccEdges().empty() &&
  481. first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
  482. u = first_op;
  483. v = second_op;
  484. } else if (!second_op->GetAliveSuccEdges().empty() &&
  485. second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
  486. u = second_op;
  487. v = first_op;
  488. } else {
  489. MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size()
  490. << ", " << second_op->GetAliveSuccEdges().size() << ".";
  491. }
  492. MS_EXCEPTION_IF_NULL(u);
  493. auto e = u->GetAliveSuccEdges()[0];
  494. auto cost_list = one_component->CreateFinalCostList(u, e, v);
  495. all_list.push_back(cost_list);
  496. } else {
  497. MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size()
  498. << " operators in a component in the final graph.";
  499. }
  500. }
  501. //
  502. auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
  503. for (size_t k = 0; k < selected_cost_list.size(); ++k) {
  504. auto selected_cost = selected_cost_list[k];
  505. if (selected_cost == nullptr) {
  506. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  507. return FAILED;
  508. }
  509. MS_EXCEPTION_IF_NULL(connected_components[k]);
  510. if (connected_components[k]->GetOperators().size() == 1) {
  511. auto u = connected_components[k]->GetOperators()[0];
  512. auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
  513. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
  514. MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
  515. } else if (connected_components[k]->GetOperators().size() == 2) {
  516. OperatorInfoPtr u = nullptr, v = nullptr;
  517. auto first_op = connected_components[k]->GetOperators()[0];
  518. auto second_op = connected_components[k]->GetOperators()[1];
  519. MS_EXCEPTION_IF_NULL(first_op);
  520. MS_EXCEPTION_IF_NULL(second_op);
  521. if (!first_op->GetAliveSuccEdges().empty() &&
  522. first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
  523. u = first_op;
  524. v = second_op;
  525. } else if (!second_op->GetAliveSuccEdges().empty() &&
  526. second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
  527. u = second_op;
  528. v = first_op;
  529. }
  530. MS_EXCEPTION_IF_NULL(u);
  531. auto e = u->GetAliveSuccEdges()[0];
  532. MS_EXCEPTION_IF_NULL(v);
  533. MS_EXCEPTION_IF_NULL(e);
  534. MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
  535. auto decision = selected_cost->decision_ptr_->cast<FinalDecisionPtr>();
  536. MS_EXCEPTION_IF_NULL(decision);
  537. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
  538. v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
  539. e->set_selected_cost(decision->middle_cost_);
  540. MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
  541. }
  542. }
  543. return SUCCESS;
  544. }
  545. // searching the strategy for the final eliminated graph
  546. Status CostGraph::SearchStrategy() {
  547. MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
  548. std::vector<OperatorInfoPtr> alive_ops;
  549. (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) {
  550. MS_EXCEPTION_IF_NULL(op);
  551. if (op->is_alive()) {
  552. alive_ops.push_back(op);
  553. }
  554. });
  555. if (alive_ops.size() > 2) {
  556. if (RUN_PHASE == TRAINING_PHASE) {
  557. // training phase
  558. return SearchStrategyForMultiNodeFinalGraph(alive_ops);
  559. } else {
  560. // inference phase
  561. MS_LOG(EXCEPTION)
  562. << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
  563. }
  564. } else if (alive_ops.size() == 1) {
  565. MS_LOG(INFO) << "There are 1 single node in the final graph.";
  566. OperatorInfoPtr u = alive_ops[0];
  567. auto cost_list = CreateFinalSingleCostList(u);
  568. CostPtr cost = nullptr;
  569. if (RUN_PHASE == TRAINING_PHASE) {
  570. // training phase
  571. cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
  572. } else {
  573. // inference phase
  574. cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
  575. }
  576. if (cost == nullptr) {
  577. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  578. return FAILED;
  579. }
  580. MS_EXCEPTION_IF_NULL(u);
  581. MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
  582. auto decision = cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
  583. MS_EXCEPTION_IF_NULL(decision);
  584. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
  585. MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
  586. return SUCCESS;
  587. } else {
  588. // In this case, the final graph should contains exactly 2 nodes.
  589. if (alive_ops.empty()) {
  590. MS_LOG(INFO) << "0 Operator in the final graph.";
  591. return SUCCESS;
  592. }
  593. OperatorInfoPtr u, v;
  594. MS_EXCEPTION_IF_NULL(alive_ops[0]);
  595. MS_EXCEPTION_IF_NULL(alive_ops[1]);
  596. if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
  597. alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
  598. u = alive_ops[0];
  599. v = alive_ops[1];
  600. } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
  601. alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
  602. u = alive_ops[1];
  603. v = alive_ops[0];
  604. } else {
  605. if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
  606. MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
  607. << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
  608. } else {
  609. // In this case, the final graph consists of two single nodes
  610. MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
  611. std::vector<CostPtrList> all_list;
  612. auto connected_components = ConstructConnectedComponents(alive_ops);
  613. MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
  614. for (size_t i = 0; i < connected_components.size(); ++i) {
  615. MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
  616. auto one_component = connected_components[i];
  617. MS_EXCEPTION_IF_NULL(one_component);
  618. auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
  619. all_list.push_back(cost_list);
  620. }
  621. CostPtrList selected_cost_list;
  622. if (RUN_PHASE == TRAINING_PHASE) {
  623. // training phase
  624. selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
  625. } else {
  626. // inference phase
  627. MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
  628. "phase is not supported.";
  629. }
  630. for (size_t k = 0; k < selected_cost_list.size(); ++k) {
  631. auto selected_cost = selected_cost_list[k];
  632. if (selected_cost == nullptr) {
  633. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  634. return FAILED;
  635. }
  636. MS_EXCEPTION_IF_NULL(connected_components[k]);
  637. auto one_operator = connected_components[k]->GetOperators()[0];
  638. MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
  639. auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
  640. MS_EXCEPTION_IF_NULL(decision);
  641. one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
  642. MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
  643. }
  644. return SUCCESS;
  645. }
  646. }
  647. MS_LOG(INFO) << "There are 2 nodes in the final graph.";
  648. // In this case, the finale graph is exactly of the form: u --> v
  649. MS_EXCEPTION_IF_NULL(u);
  650. MS_EXCEPTION_IF_NULL(v);
  651. auto e = u->GetAliveSuccEdges()[0];
  652. MS_EXCEPTION_IF_NULL(e);
  653. auto cost_list = CreateFinalCostList(u, e, v);
  654. CostPtr cost = nullptr;
  655. if (RUN_PHASE == TRAINING_PHASE) {
  656. // training phase
  657. cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
  658. } else {
  659. MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
  660. "phase is not supported.";
  661. }
  662. if (cost == nullptr) {
  663. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  664. return FAILED;
  665. }
  666. MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
  667. auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
  668. MS_EXCEPTION_IF_NULL(decision);
  669. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
  670. v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
  671. e->set_selected_cost(decision->middle_cost_);
  672. MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
  673. return SUCCESS;
  674. }
  675. }
  676. // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
  677. // return the v and the edge u --> v
  678. OperatorInfoPtr CostGraph::CheckOpElimination() const {
  679. for (auto &op : ops_) {
  680. bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1;
  681. if (bool_test) {
  682. if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) {
  683. return op;
  684. }
  685. }
  686. }
  687. return nullptr;
  688. }
  689. // Check the graph whether an EdgeElimination can be performed
  690. std::vector<std::shared_ptr<Edge>> CostGraph::CheckEdgeElimination() const {
  691. for (auto &op : ops_) {
  692. MS_EXCEPTION_IF_NULL(op);
  693. if (!op->is_alive()) continue;
  694. std::map<void *, int> count;
  695. for (auto &edge : op->GetAliveSuccEdges()) {
  696. MS_EXCEPTION_IF_NULL(edge);
  697. auto v = edge->next_operator();
  698. count[v.get()]++;
  699. }
  700. for (auto &pair : count) {
  701. auto *op_ptr = pair.first;
  702. int op_count = pair.second;
  703. if (op_count > 1) {
  704. std::vector<std::shared_ptr<Edge>> ret;
  705. for (auto &edge : op->GetAliveSuccEdges()) {
  706. MS_EXCEPTION_IF_NULL(edge);
  707. if (edge->next_operator().get() == op_ptr) {
  708. ret.push_back(edge);
  709. }
  710. }
  711. return ret;
  712. }
  713. }
  714. }
  715. return {};
  716. }
  717. // Check the graph whether a MergeElimination can be performed
  718. OperatorInfoPtr CostGraph::CheckMergeElimination() const {
  719. for (auto &op : ops_) {
  720. MS_EXCEPTION_IF_NULL(op);
  721. bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1;
  722. if (bool_test) {
  723. auto next_op = op->GetAliveSuccEdges()[0]->next_operator();
  724. MS_EXCEPTION_IF_NULL(next_op);
  725. if (!next_op->GetAlivePrevEdges().empty()) {
  726. return op;
  727. }
  728. }
  729. }
  730. return nullptr;
  731. }
  732. // Check the graph whether a ContractElimination can be performed
  733. OperatorInfoPtr CostGraph::CheckContractElimination() const {
  734. for (auto &op : ops_) {
  735. MS_EXCEPTION_IF_NULL(op);
  736. bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty();
  737. if (bool_test) {
  738. auto edge = op->GetAlivePrevEdges()[0];
  739. MS_EXCEPTION_IF_NULL(edge);
  740. auto prev_op = edge->prev_operator();
  741. MS_EXCEPTION_IF_NULL(prev_op);
  742. if (!prev_op->GetAliveSuccEdges().empty()) {
  743. return op;
  744. }
  745. }
  746. }
  747. return nullptr;
  748. }
  749. // Check the graph whether a TriangleElimination can be performed
  750. std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
  751. for (auto &op : ops_) {
  752. MS_EXCEPTION_IF_NULL(op);
  753. bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2);
  754. if (bool_test) {
  755. auto edge1 = op->GetAliveSuccEdges()[0];
  756. auto edge2 = op->GetAliveSuccEdges()[1];
  757. MS_EXCEPTION_IF_NULL(edge1);
  758. MS_EXCEPTION_IF_NULL(edge2);
  759. auto first_op = edge1->next_operator();
  760. auto second_op = edge2->next_operator();
  761. MS_EXCEPTION_IF_NULL(first_op);
  762. for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) {
  763. if (first_op_succ_edge->next_operator() == second_op) {
  764. return {op, first_op_succ_edge};
  765. }
  766. }
  767. MS_EXCEPTION_IF_NULL(second_op);
  768. for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) {
  769. if (second_op_succ_edge->next_operator() == first_op) {
  770. return {op, second_op_succ_edge};
  771. }
  772. }
  773. }
  774. }
  775. return {nullptr, nullptr};
  776. }
  777. // Check the graph whether a StarElimination can be performed.
  778. // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
  779. OperatorInfoPtr CostGraph::CheckStarElimination() const {
  780. for (auto &op : ops_) {
  781. MS_EXCEPTION_IF_NULL(op);
  782. bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1);
  783. if (bool_test) {
  784. return op;
  785. }
  786. }
  787. return nullptr;
  788. }
  789. // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace
  790. // 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge.
  791. std::shared_ptr<Edge> CostGraph::EliminationOp(const OperatorInfoPtr &op) {
  792. // in this case, the operators are organised in the form of u-->op-->v, and the goal
  793. // is to eliminate 'op'.
  794. MS_EXCEPTION_IF_NULL(op);
  795. MS_LOG(INFO) << "Now eliminating node: " << op->name() << ".";
  796. auto edge_u_op = op->GetAlivePrevEdges()[0];
  797. auto edge_op_v = op->GetAliveSuccEdges()[0];
  798. MS_EXCEPTION_IF_NULL(edge_u_op);
  799. MS_EXCEPTION_IF_NULL(edge_op_v);
  800. auto u = edge_u_op->prev_operator();
  801. auto v = edge_op_v->next_operator();
  802. std::vector<size_t> output_indexs, input_indexs;
  803. size_t output_index, input_index;
  804. MS_EXCEPTION_IF_NULL(u);
  805. MS_EXCEPTION_IF_NULL(v);
  806. std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
  807. std::shared_ptr<Edge> new_edge;
  808. if (edge_u_op->is_combined()) {
  809. output_indexs = edge_u_op->prev_op_output_indexs();
  810. } else {
  811. output_index = edge_u_op->prev_op_output_index();
  812. output_indexs.push_back(output_index);
  813. }
  814. if (edge_op_v->is_combined()) {
  815. input_indexs = edge_op_v->next_op_input_indexs();
  816. } else {
  817. input_index = edge_op_v->next_op_input_index();
  818. input_indexs.push_back(input_index);
  819. }
  820. if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) {
  821. new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_index, input_index, false);
  822. } else {
  823. new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
  824. }
  825. MS_EXCEPTION_IF_NULL(new_edge);
  826. new_edge->set_pre_op_output(edge_u_op->prev_op_output());
  827. new_edge->set_next_op_input(edge_op_v->next_op_input());
  828. new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v);
  829. u->ReplaceSuccEdge(op, new_edge);
  830. v->ReplacePreEdge(op, new_edge);
  831. op->SetNotAlive();
  832. MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded.";
  833. return new_edge;
  834. }
  835. // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges',
  836. // and sets new costlist for the new edge.
  837. std::shared_ptr<Edge> CostGraph::EliminationEdges(const std::vector<std::shared_ptr<Edge>> &edges) {
  838. MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges.";
  839. MS_EXCEPTION_IF_NULL(edges[0]);
  840. auto u = edges[0]->prev_operator();
  841. auto v = edges[0]->next_operator();
  842. MS_EXCEPTION_IF_NULL(u);
  843. MS_EXCEPTION_IF_NULL(v);
  844. std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
  845. std::vector<size_t> output_indexs, input_indexs;
  846. for (auto &edge : edges) {
  847. MS_EXCEPTION_IF_NULL(edge);
  848. if (edge->is_combined()) {
  849. auto from_output_indexs = edge->prev_op_output_indexs();
  850. auto from_input_indexs = edge->next_op_input_indexs();
  851. (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs));
  852. (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs));
  853. } else {
  854. output_indexs.push_back(edge->prev_op_output_index());
  855. input_indexs.push_back(edge->next_op_input_index());
  856. }
  857. }
  858. std::shared_ptr<Edge> new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
  859. MS_EXCEPTION_IF_NULL(new_edge);
  860. new_edge->set_pre_op_output(edges[0]->prev_op_output());
  861. new_edge->set_next_op_input(edges[0]->next_op_input());
  862. new_edge->EdgeEliminationSetNewCost(u, edges, v);
  863. u->ReplaceSuccEdges(v, new_edge);
  864. v->ReplacePreEdges(u, new_edge);
  865. MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded.";
  866. return new_edge;
  867. }
  868. // Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
  869. // for this contract under the strategy 'op_strategy'
  870. void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
  871. const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
  872. const CostPtrList &tar_cost_list,
  873. CostPtrList *const tar_cost_list_new) {
  874. for (size_t i = 0; i < op_cost_list.size(); ++i) {
  875. auto &op_cost = op_cost_list[i];
  876. MS_EXCEPTION_IF_NULL(op_cost);
  877. for (size_t j = 0; j < edge_cost_list.size(); ++j) {
  878. auto &edge_cost = edge_cost_list[j];
  879. MS_EXCEPTION_IF_NULL(edge_cost);
  880. for (size_t k = 0; k < tar_cost_list.size(); ++k) {
  881. auto &tar_cost = tar_cost_list[k];
  882. MS_EXCEPTION_IF_NULL(tar_cost);
  883. double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
  884. double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
  885. double communication =
  886. op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
  887. double communication_forward =
  888. op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
  889. double communication_without_para = op_cost->communication_without_parameter_ +
  890. edge_cost->communication_without_parameter_ +
  891. tar_cost->communication_without_parameter_;
  892. auto decision =
  893. std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
  894. auto new_cost = std::make_shared<Cost>(computation, communication, decision);
  895. MS_EXCEPTION_IF_NULL(new_cost);
  896. new_cost->communication_without_parameter_ = communication_without_para;
  897. new_cost->communication_with_partial_para_ =
  898. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  899. new_cost->memory_with_reuse_ = memory;
  900. new_cost->communication_forward_ = communication_forward;
  901. MS_EXCEPTION_IF_NULL(tar_cost_list_new);
  902. tar_cost_list_new->emplace_back(std::move(new_cost));
  903. }
  904. }
  905. }
  906. }
  907. // This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the
  908. // target_op
  909. OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {
  910. MS_EXCEPTION_IF_NULL(op);
  911. auto target_op = op->GetAliveSuccEdges()[0]->next_operator();
  912. auto edge_ptr = op->GetAliveSuccEdges()[0];
  913. MS_EXCEPTION_IF_NULL(target_op);
  914. MS_EXCEPTION_IF_NULL(edge_ptr);
  915. MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << ".";
  916. bool valid = false;
  917. for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
  918. MS_EXCEPTION_IF_NULL(tar_stra_cost);
  919. auto tar_stra = tar_stra_cost->strategy_ptr;
  920. auto tar_clist_origin = tar_stra_cost->cost_list;
  921. CostPtrList tar_clist_new;
  922. for (auto &op_stra_cost : op->GetStrategyCost()) {
  923. MS_EXCEPTION_IF_NULL(op_stra_cost);
  924. auto op_stra = op_stra_cost->strategy_ptr;
  925. auto op_clist = op_stra_cost->cost_list;
  926. auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra);
  927. CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
  928. }
  929. Simplify(&tar_clist_new);
  930. // Set the new costlist w.r.t the strategy
  931. tar_stra_cost->cost_list = tar_clist_new;
  932. if ((!valid) && (!tar_clist_new.empty())) {
  933. valid = true;
  934. }
  935. }
  936. if (!valid) {
  937. MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed.";
  938. }
  939. op->SetNotAlive();
  940. MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded.";
  941. return target_op;
  942. }
  943. // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
  944. // for this contract under the strategy 'contract_op_stra'
  945. void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,
  946. const CostPtrList &contract_op_cost_list,
  947. const CostPtrList &edge_cost_list, StrategyPtr target_op_stra,
  948. const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) {
  949. for (size_t i = 0; i < contract_op_cost_list.size(); ++i) {
  950. auto &contract_op_cost = contract_op_cost_list[i];
  951. MS_EXCEPTION_IF_NULL(contract_op_cost);
  952. for (size_t j = 0; j < edge_cost_list.size(); ++j) {
  953. auto &edge_cost = edge_cost_list[j];
  954. MS_EXCEPTION_IF_NULL(edge_cost);
  955. for (size_t k = 0; k < tar_cost_list.size(); ++k) {
  956. auto &tar_cost = tar_cost_list[k];
  957. MS_EXCEPTION_IF_NULL(tar_cost);
  958. double computation =
  959. contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
  960. double memory =
  961. contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
  962. double communication =
  963. contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
  964. double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
  965. tar_cost->communication_forward_;
  966. double communication_without_para = contract_op_cost->communication_without_parameter_ +
  967. edge_cost->communication_without_parameter_ +
  968. tar_cost->communication_without_parameter_;
  969. auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
  970. target_op_stra, tar_cost);
  971. auto new_cost = std::make_shared<Cost>(computation, communication, decision);
  972. new_cost->communication_without_parameter_ = communication_without_para;
  973. new_cost->communication_with_partial_para_ =
  974. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  975. new_cost->memory_with_reuse_ = memory;
  976. new_cost->communication_forward_ = communication_forward;
  977. tar_cost_list_new->emplace_back(std::move(new_cost));
  978. }
  979. }
  980. }
  981. }
  982. // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the
  983. // target_op
  984. OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {
  985. MS_EXCEPTION_IF_NULL(op);
  986. auto target_op = op->GetAlivePrevEdges()[0]->prev_operator();
  987. auto edge_ptr = op->GetAlivePrevEdges()[0];
  988. MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << ".";
  989. bool valid = false;
  990. for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
  991. MS_EXCEPTION_IF_NULL(tar_stra_cost);
  992. auto tar_stra = tar_stra_cost->strategy_ptr;
  993. auto tar_clist_origin = tar_stra_cost->cost_list;
  994. CostPtrList tar_clist_new;
  995. for (auto &op_stra_cost : op->GetStrategyCost()) {
  996. MS_EXCEPTION_IF_NULL(op_stra_cost);
  997. auto op_stra = op_stra_cost->strategy_ptr;
  998. auto op_clist = op_stra_cost->cost_list;
  999. auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra);
  1000. CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
  1001. }
  1002. Simplify(&tar_clist_new);
  1003. // Set the new costlist w.r.t the strategy
  1004. tar_stra_cost->cost_list = tar_clist_new;
  1005. if ((!valid) && (!tar_clist_new.empty())) {
  1006. valid = true;
  1007. }
  1008. }
  1009. if (!valid) {
  1010. MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed.";
  1011. }
  1012. op->SetNotAlive();
  1013. MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded.";
  1014. return target_op;
  1015. }
  1016. void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
  1017. StrategyPtr right_op_stra, const CostPtr &right_op_cost,
  1018. const CostPtrList &elimi_op_clist,
  1019. const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost,
  1020. const CostPtrList &left_node_clist_origin,
  1021. CostPtrList *left_node_clist_new) {
  1022. MS_EXCEPTION_IF_NULL(right_edge_cost);
  1023. MS_EXCEPTION_IF_NULL(right_op_cost);
  1024. MS_EXCEPTION_IF_NULL(left_node_clist_new);
  1025. for (auto &elimi_op_cost : elimi_op_clist) {
  1026. MS_EXCEPTION_IF_NULL(elimi_op_cost);
  1027. for (auto &left_edge_cost : left_edge_clist) {
  1028. MS_EXCEPTION_IF_NULL(left_edge_cost);
  1029. for (auto &left_node_cost : left_node_clist_origin) {
  1030. MS_EXCEPTION_IF_NULL(left_node_cost);
  1031. double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
  1032. left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
  1033. double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
  1034. left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
  1035. double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
  1036. left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
  1037. double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
  1038. left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
  1039. double new_commu_without =
  1040. elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
  1041. left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
  1042. auto decision = std::make_shared<TriangleEliminationDecision>(
  1043. elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra);
  1044. auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
  1045. new_cost->communication_without_parameter_ = new_commu_without;
  1046. new_cost->communication_with_partial_para_ =
  1047. new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
  1048. new_cost->memory_with_reuse_ = new_memory;
  1049. new_cost->communication_forward_ = new_commu_forward;
  1050. left_node_clist_new->emplace_back(std::move(new_cost));
  1051. }
  1052. }
  1053. }
  1054. }
  1055. void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist,
  1056. const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra,
  1057. const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra,
  1058. const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist,
  1059. const CostPtrList &left_node_clist_origin,
  1060. CostPtrList *left_node_clist_new) {
  1061. MS_EXCEPTION_IF_NULL(elimi_op);
  1062. for (auto &right_node_cost : right_node_clist) {
  1063. MS_EXCEPTION_IF_NULL(right_node_cost);
  1064. for (auto &right_edge_cost : right_edge_clist) {
  1065. MS_EXCEPTION_IF_NULL(right_edge_cost);
  1066. CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
  1067. elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
  1068. left_node_clist_new);
  1069. }
  1070. }
  1071. }
  1072. OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
  1073. const std::shared_ptr<Edge> &edge_left_right) {
  1074. MS_EXCEPTION_IF_NULL(edge_left_right);
  1075. MS_EXCEPTION_IF_NULL(elimi_op);
  1076. MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << ".";
  1077. auto left_node = edge_left_right->prev_operator();
  1078. auto right_node = edge_left_right->next_operator();
  1079. auto left_edge = elimi_op->GetAliveSuccEdges()[0];
  1080. auto right_edge = elimi_op->GetAliveSuccEdges()[1];
  1081. MS_EXCEPTION_IF_NULL(left_node);
  1082. MS_EXCEPTION_IF_NULL(right_node);
  1083. MS_EXCEPTION_IF_NULL(left_edge);
  1084. MS_EXCEPTION_IF_NULL(right_edge);
  1085. MS_LOG(INFO) << "The left operator is: " << left_node->name() << ".";
  1086. MS_LOG(INFO) << "The right operator is: " << right_node->name() << ".";
  1087. if (left_edge->next_operator() != left_node) {
  1088. auto tmp = left_edge;
  1089. left_edge = right_edge;
  1090. right_edge = tmp;
  1091. }
  1092. bool valid = false;
  1093. for (auto &left_node_stra_cost : left_node->GetStrategyCost()) {
  1094. MS_EXCEPTION_IF_NULL(left_node_stra_cost);
  1095. auto left_node_stra = left_node_stra_cost->strategy_ptr;
  1096. auto left_node_clist_origin = left_node_stra_cost->cost_list;
  1097. CostPtrList left_node_clist_new;
  1098. for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) {
  1099. MS_EXCEPTION_IF_NULL(elimi_op_stra_cost);
  1100. auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr;
  1101. auto elimi_op_clist = elimi_op_stra_cost->cost_list;
  1102. auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra);
  1103. for (auto &right_node_stra_cost : right_node->GetStrategyCost()) {
  1104. MS_EXCEPTION_IF_NULL(right_node_stra_cost);
  1105. auto right_node_stra = right_node_stra_cost->strategy_ptr;
  1106. auto right_node_clist = right_node_stra_cost->cost_list;
  1107. auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra);
  1108. CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra,
  1109. right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin,
  1110. &left_node_clist_new);
  1111. }
  1112. }
  1113. Simplify(&left_node_clist_new);
  1114. // Set the new costlist w.r.t the strategy
  1115. left_node_stra_cost->cost_list = left_node_clist_new;
  1116. if ((!valid) && (!left_node_clist_new.empty())) {
  1117. valid = true;
  1118. }
  1119. }
  1120. if (!valid) {
  1121. MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed.";
  1122. }
  1123. elimi_op->SetNotAlive();
  1124. MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded.";
  1125. return left_node;
  1126. }
  1127. void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra,
  1128. const CostPtrList &first_succ_node_clist,
  1129. const CostPtrList &first_succ_edge_clist,
  1130. const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
  1131. std::vector<StrategyPtr> succ_nodes_stras,
  1132. CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs,
  1133. CostPtrList *first_succ_node_clist_new) {
  1134. for (auto &first_succ_node_cost : first_succ_node_clist) {
  1135. for (auto &first_succ_edge_cost : first_succ_edge_clist) {
  1136. for (auto &merged_node_cost : merged_op_clist) {
  1137. MS_EXCEPTION_IF_NULL(merged_node_cost);
  1138. succ_nodes_stras[0] = first_succ_node_stra;
  1139. succ_edges_costs[0] = first_succ_edge_cost;
  1140. succ_nodes_costs[0] = first_succ_node_cost;
  1141. double computation_cost = merged_node_cost->computation_cost_,
  1142. memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
  1143. commu_without = merged_node_cost->communication_without_parameter_,
  1144. commu_forward = merged_node_cost->communication_forward_;
  1145. for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
  1146. MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
  1147. if (i == 0) {
  1148. computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
  1149. memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
  1150. commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
  1151. commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
  1152. commu_without += succ_edges_costs[i]->communication_without_parameter_ +
  1153. succ_nodes_costs[i]->communication_without_parameter_;
  1154. } else {
  1155. computation_cost += succ_edges_costs[i]->computation_cost_;
  1156. memory_cost += succ_edges_costs[i]->memory_with_reuse_;
  1157. commu_cost += succ_edges_costs[i]->communication_cost_;
  1158. commu_forward += succ_edges_costs[i]->communication_forward_;
  1159. commu_without += succ_edges_costs[i]->communication_without_parameter_;
  1160. }
  1161. }
  1162. auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
  1163. succ_nodes_stras, succ_nodes_costs);
  1164. auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
  1165. new_cost->communication_without_parameter_ = commu_without;
  1166. new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
  1167. new_cost->memory_with_reuse_ = memory_cost;
  1168. new_cost->communication_forward_ = commu_forward;
  1169. first_succ_node_clist_new->emplace_back(std::move(new_cost));
  1170. }
  1171. }
  1172. }
  1173. }
  1174. void CostGraph::CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> &succ_edges,
  1175. const StrategyPtr &first_succ_node_stra,
  1176. const CostPtrList &first_succ_node_clist,
  1177. const CostPtrList &first_succ_edge_clist,
  1178. const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
  1179. CostPtrList *first_succ_node_clist_new) {
  1180. std::vector<StrategyPtr> succ_nodes_stras(succ_edges.size(), nullptr);
  1181. CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr);
  1182. std::function<void(size_t)> recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist,
  1183. &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs,
  1184. &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive,
  1185. this](size_t k) {
  1186. if (k == succ_edges.size()) {
  1187. CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
  1188. merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs,
  1189. succ_nodes_costs, first_succ_node_clist_new);
  1190. return;
  1191. }
  1192. MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size()
  1193. << ", first_succ_edge_clist: " << first_succ_edge_clist.size()
  1194. << ", merged_op_clist: " << merged_op_clist.size()
  1195. << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << ".";
  1196. auto succ_edge = succ_edges[k];
  1197. MS_EXCEPTION_IF_NULL(succ_edge);
  1198. auto succ_node = succ_edge->next_operator();
  1199. MS_EXCEPTION_IF_NULL(succ_node);
  1200. for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) {
  1201. MS_EXCEPTION_IF_NULL(succ_node_stra_cost);
  1202. auto succ_node_stra = succ_node_stra_cost->strategy_ptr;
  1203. auto succ_node_clist = succ_node_stra_cost->cost_list;
  1204. auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra);
  1205. for (auto &succ_node_cost : succ_node_clist) {
  1206. MS_EXCEPTION_IF_NULL(succ_node_cost);
  1207. for (auto &succ_edge_cost : succ_edge_clist) {
  1208. MS_EXCEPTION_IF_NULL(succ_edge_cost);
  1209. succ_nodes_stras[k] = succ_node_stra;
  1210. succ_edges_costs[k] = succ_edge_cost;
  1211. succ_nodes_costs[k] = succ_node_cost;
  1212. recursive(k + 1);
  1213. }
  1214. }
  1215. }
  1216. };
  1217. recursive(1);
  1218. }
  1219. std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) {
  1220. MS_EXCEPTION_IF_NULL(merged_op);
  1221. auto succ_edges = merged_op->GetAliveSuccEdges();
  1222. MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << ".";
  1223. for (auto &succ_edge : succ_edges) {
  1224. MS_EXCEPTION_IF_NULL(succ_edge->next_operator());
  1225. MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << ".";
  1226. }
  1227. MS_EXCEPTION_IF_NULL(succ_edges[0]);
  1228. auto first_succ_node = succ_edges[0]->next_operator();
  1229. auto first_succ_edge = succ_edges[0];
  1230. bool valid = false;
  1231. // 'merged_op' is merged into first_node
  1232. MS_EXCEPTION_IF_NULL(first_succ_node);
  1233. for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) {
  1234. MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost);
  1235. auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr;
  1236. auto first_succ_node_clist = first_succ_node_stra_cost->cost_list;
  1237. CostPtrList first_succ_node_clist_new;
  1238. for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) {
  1239. MS_EXCEPTION_IF_NULL(merged_op_stra_cost);
  1240. auto merged_op_stra = merged_op_stra_cost->strategy_ptr;
  1241. auto merged_op_clist = merged_op_stra_cost->cost_list;
  1242. auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra);
  1243. CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
  1244. merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
  1245. }
  1246. Simplify(&first_succ_node_clist_new);
  1247. // Set the new costlist w.r.t the strategy
  1248. first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
  1249. if ((!valid) && (!first_succ_node_clist_new.empty())) {
  1250. valid = true;
  1251. }
  1252. }
  1253. if (!valid) {
  1254. MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed.";
  1255. }
  1256. merged_op->SetNotAlive();
  1257. MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded.";
  1258. return succ_edges;
  1259. }
  1260. size_t CostGraph::GetNumEdges() const {
  1261. size_t sum = 0;
  1262. for (const auto &kv : edges_) {
  1263. auto &edges = kv.second;
  1264. sum += edges.size();
  1265. }
  1266. return sum;
  1267. }
  1268. Status CostGraph::InitSelectedStrategy() {
  1269. for (auto &op : ops_) {
  1270. MS_EXCEPTION_IF_NULL(op);
  1271. if (op->name().find(RESHAPEINFO) != std::string::npos) {
  1272. continue;
  1273. }
  1274. auto result = op->InitSelectedStrategy(op->selected_strategy());
  1275. if (result != SUCCESS) {
  1276. return result;
  1277. }
  1278. }
  1279. // reshape init should be apply after the init of it's previous node and next node.
  1280. for (size_t i = 0; i < ops_.size(); ++i) {
  1281. if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
  1282. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
  1283. auto in_edges = GetOriginalPrevEdges(ops_[i]);
  1284. auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr<Edge> edge) {
  1285. return edge->prev_operator()->name() == reshape_info->pre_operator_name();
  1286. });
  1287. auto out_edges = GetOriginalNextEdges(ops_[i]);
  1288. auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
  1289. return edge->next_operator()->name() == reshape_info->next_operator_name();
  1290. });
  1291. if (pre_iter != in_edges.end()) {
  1292. MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
  1293. int32_t pre_index = reshape_info->pre_operator_index();
  1294. TensorInfo pre_info;
  1295. if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
  1296. pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
  1297. } else {
  1298. pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
  1299. }
  1300. reshape_info->SetInputLayout(pre_info.tensor_layout());
  1301. Dimensions stra = pre_info.InferStrategy();
  1302. if (stra.empty()) {
  1303. MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
  1304. }
  1305. std::vector<Dimensions> stra_inputs = {stra};
  1306. StrategyPtr reshape_stra =
  1307. std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
  1308. reshape_info->set_strategy(reshape_stra);
  1309. }
  1310. if (next_iter != out_edges.end()) {
  1311. MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
  1312. int32_t next_index = reshape_info->next_operator_index();
  1313. reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
  1314. }
  1315. return reshape_info->Init(nullptr);
  1316. }
  1317. }
  1318. return SUCCESS;
  1319. }
  1320. Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
  1321. for (auto &op : ops_) {
  1322. MS_EXCEPTION_IF_NULL(op);
  1323. const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved();
  1324. if ((output_parameter != 0) && (output_parameter != 1)) {
  1325. MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed.";
  1326. return FAILED;
  1327. }
  1328. }
  1329. return SUCCESS;
  1330. }
  1331. void CostGraph::DFSForTopoOrder(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
  1332. std::vector<OperatorInfoPtr> *topo_order) {
  1333. MS_EXCEPTION_IF_NULL(current_op);
  1334. MS_EXCEPTION_IF_NULL(visited);
  1335. MS_EXCEPTION_IF_NULL(topo_order);
  1336. visited->at(current_op) = true;
  1337. for (const auto &s_edge : current_op->succ_edges()) {
  1338. if (!visited->at(s_edge->next_operator())) {
  1339. DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
  1340. }
  1341. }
  1342. topo_order->push_back(current_op);
  1343. }
  1344. // Compute a topological order of the costgraph
  1345. void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
  1346. std::map<OperatorInfoPtr, bool> visited;
  1347. for (auto &op : ops_) {
  1348. visited[op] = false;
  1349. }
  1350. for (auto &op : ops_) {
  1351. if (!visited[op]) {
  1352. DFSForTopoOrder(op, &visited, topo_order);
  1353. }
  1354. }
  1355. }
  1356. void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) {
  1357. for (auto &op : ops_) {
  1358. auto search = candidate_ops.find(op);
  1359. if (search != candidate_ops.end()) {
  1360. // Mark the critical operators
  1361. op->mark_output_critical();
  1362. // Mark the successive edges
  1363. for (auto &s_edge : op->succ_edges()) {
  1364. s_edge->mark_output_critical();
  1365. }
  1366. } else {
  1367. op->mark_output_not_critical();
  1368. }
  1369. }
  1370. }
  1371. Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
  1372. if (topo_order.size() == 0) {
  1373. MS_LOG(ERROR) << "0 operator in costgraph.";
  1374. return FAILED;
  1375. }
  1376. auto &first_op = topo_order[0];
  1377. if (first_op->prev_edges().size() > 0) {
  1378. MS_LOG(ERROR) << "The first operator in the first of topological order of "
  1379. "costgraph should have 0 incoming edge, but has "
  1380. << first_op->prev_edges() << "edges.";
  1381. return FAILED;
  1382. }
  1383. // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
  1384. // of the output of OperatorInfo that currently has not been used
  1385. std::map<OperatorInfoPtr, int> curr_memory_state;
  1386. (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size())));
  1387. std::map<OperatorInfoPtr, int> max_memory_state = curr_memory_state;
  1388. // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
  1389. // not been used
  1390. double curr_memory_size = first_op->GetOutputsTotalSize();
  1391. double max_memory_size = curr_memory_size;
  1392. for (size_t finished = 1; finished < topo_order.size(); ++finished) {
  1393. // Produce
  1394. (void)curr_memory_state.emplace(
  1395. std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size())));
  1396. curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
  1397. // Consume
  1398. for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
  1399. const auto &prev_op = prev_edge->prev_operator();
  1400. curr_memory_state[prev_op]--;
  1401. }
  1402. for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
  1403. const auto &prev_op = prev_edge->prev_operator();
  1404. if (curr_memory_state[prev_op] < 0) {
  1405. MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
  1406. return FAILED;
  1407. } else if (curr_memory_state[prev_op] == 0) {
  1408. curr_memory_state.erase(prev_op);
  1409. curr_memory_size -= prev_op->GetOutputsTotalSize();
  1410. }
  1411. }
  1412. if (curr_memory_size < 0) {
  1413. MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
  1414. }
  1415. // Modify the max
  1416. if (curr_memory_size > max_memory_size) {
  1417. max_memory_size = curr_memory_size;
  1418. max_memory_state = curr_memory_state;
  1419. }
  1420. }
  1421. // Mark those critical operators
  1422. MarkCriticalOpsAndEdges(max_memory_state);
  1423. return SUCCESS;
  1424. }
  1425. Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
  1426. // Two steps to do:
  1427. // 1. Compute a topological order of the costgraph
  1428. // 2. Determine and mark the operators (and necessary edges) that are critical
  1429. std::vector<OperatorInfoPtr> topo_order;
  1430. TopologyOrder(&topo_order);
  1431. std::reverse(std::begin(topo_order), std::end(topo_order));
  1432. if (DetermineCriticalOps(topo_order) != SUCCESS) {
  1433. MS_LOG(ERROR) << "Determining critical operators failed.";
  1434. return FAILED;
  1435. }
  1436. return SUCCESS;
  1437. }
  1438. Status CostGraph::CalculateOpsMemoryCost() {
  1439. for (auto &op : ops_) {
  1440. MS_EXCEPTION_IF_NULL(op);
  1441. if (op->CalculateMemoryCost() != SUCCESS) {
  1442. MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
  1443. return FAILED;
  1444. }
  1445. }
  1446. return SUCCESS;
  1447. }
  1448. Status CostGraph::CalculateOpsMemoryCostForInference() {
  1449. for (auto &op : ops_) {
  1450. MS_EXCEPTION_IF_NULL(op);
  1451. if (op->CalculateMemoryCostForInference() != SUCCESS) {
  1452. MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
  1453. return FAILED;
  1454. }
  1455. }
  1456. return SUCCESS;
  1457. }
  1458. Status CostGraph::CalculateEdgesMemoryCost() {
  1459. for (auto &edge_pair : edges_) {
  1460. const auto &edges = edge_pair.second;
  1461. for (auto &one_edge : edges) {
  1462. if (one_edge->CalculateMemoryCost() != SUCCESS) {
  1463. MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
  1464. return FAILED;
  1465. }
  1466. }
  1467. }
  1468. return SUCCESS;
  1469. }
  1470. Status CostGraph::CalculateEdgesMemoryCostForInference() {
  1471. for (auto &edge_pair : edges_) {
  1472. const auto &edges = edge_pair.second;
  1473. for (auto &one_edge : edges) {
  1474. if (one_edge->CalculateMemoryCostForInference() != SUCCESS) {
  1475. MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
  1476. return FAILED;
  1477. }
  1478. }
  1479. }
  1480. return SUCCESS;
  1481. }
  1482. OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
  1483. for (auto one_op : ops_) {
  1484. if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
  1485. if (one_op->refkey_parameter_name() == p_name) {
  1486. return one_op;
  1487. }
  1488. }
  1489. }
  1490. return nullptr;
  1491. }
  1492. Status CostGraph::CorrectOpsMemoryCost() {
  1493. for (auto &one_op : ops_) {
  1494. if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
  1495. if (one_op->GetAliveSuccEdges().size() > 1) {
  1496. // Filter out the case when the TmpIdentity being used by multiple operators
  1497. std::map<size_t, int> output_count;
  1498. for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
  1499. auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
  1500. output_count[output_index]++;
  1501. }
  1502. for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
  1503. auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
  1504. if (output_count[output_index] <= 1) {
  1505. continue;
  1506. }
  1507. auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
  1508. MS_EXCEPTION_IF_NULL(next_op);
  1509. auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
  1510. if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
  1511. MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
  1512. << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
  1513. return FAILED;
  1514. }
  1515. output_count[output_index]--;
  1516. }
  1517. }
  1518. }
  1519. }
  1520. return SUCCESS;
  1521. }
  1522. Status CostGraph::CalculateMemoryCost() {
  1523. if (RUN_PHASE == TRAINING_PHASE) {
  1524. // training phase
  1525. if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
  1526. // Calculate operators' memory usage
  1527. if (CalculateOpsMemoryCost() != SUCCESS) {
  1528. MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
  1529. return FAILED;
  1530. }
  1531. // Calculate edges' memory usage
  1532. if (CalculateEdgesMemoryCost() != SUCCESS) {
  1533. MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
  1534. return FAILED;
  1535. }
  1536. // Correct memory usage caused by TmpIdentity
  1537. if (CorrectOpsMemoryCost() != SUCCESS) {
  1538. MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
  1539. return FAILED;
  1540. }
  1541. } else {
  1542. MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
  1543. return FAILED;
  1544. }
  1545. } else {
  1546. // inference phase
  1547. if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
  1548. // Calculate operators' memory usage
  1549. if (CalculateOpsMemoryCostForInference() != SUCCESS) {
  1550. MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
  1551. return FAILED;
  1552. }
  1553. // Calculate edges's memory usage
  1554. if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
  1555. MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
  1556. return FAILED;
  1557. }
  1558. } else {
  1559. MS_LOG(ERROR) << "Computing operators' critical flag failed.";
  1560. return FAILED;
  1561. }
  1562. }
  1563. return SUCCESS;
  1564. }
  1565. } // namespace parallel
  1566. } // namespace mindspore