You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_costmodel.cc 75 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <algorithm>
  17. #include <cstdlib>
  18. #include <iterator>
  19. #include <numeric>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "parallel/auto_parallel/graph_costmodel.h"
  24. #include "parallel/ops_info/reshape_info.h"
  25. #include "parallel/step_auto_parallel.h"
  26. namespace mindspore {
  27. namespace parallel {
  28. CostGraphPtr entire_costgraph = nullptr;
  29. size_t TOTAL_OPS = 0;
  30. double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA;
  31. bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
  32. double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY;
  33. double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
  34. double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST;
  35. double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS;
  36. bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
  37. size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
  38. bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
  39. bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
  40. bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
  41. int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
  42. constexpr char RESHAPEINFO[] = "ReshapeInfo";
  43. void CostGraph::SetDeviceMemoryAndCostParameter() {
  44. MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
  45. // DEVICE_MEMORY_CAPACITY
  46. auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
  47. if (device_memory <= 0) {
  48. MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
  49. }
  50. dev_memory_ = device_memory;
  51. DEVICE_MEMORY_CAPACITY = device_memory;
  52. MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << ".";
  53. // COST_MODEL_ALPHA
  54. auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
  55. if (alpha <= 0) {
  56. MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
  57. }
  58. costmodel_alpha_ = alpha;
  59. MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
  60. // COST_MODEL_BETA
  61. auto beta = CostModelContext::GetInstance()->costmodel_beta();
  62. if (beta <= 0) {
  63. MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
  64. }
  65. costmodel_beta_ = beta;
  66. MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
  67. // COST_MODEL_GAMMA
  68. auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  69. if ((gamma < 0) || (gamma > 1)) {
  70. MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1].";
  71. }
  72. COST_MODEL_GAMMA = gamma;
  73. MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << ".";
  74. // COST_MODEL_SIMPLIFY_CALCULATION
  75. auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal();
  76. COST_MODEL_SIMPLIFY_CALCULATION = simplify;
  77. if (COST_MODEL_SIMPLIFY_CALCULATION) {
  78. MS_LOG(INFO) << "costmodel_simplify_cal: true.";
  79. } else {
  80. MS_LOG(INFO) << "costmodel_simplify_cal: false.";
  81. }
  82. // COST_MODEL_COMMUNI_THRESHOLD
  83. auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
  84. if (communi_threshold < 0) {
  85. MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero.";
  86. }
  87. COST_MODEL_COMMUNI_THRESHOLD = communi_threshold;
  88. MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << ".";
  89. // COST_MODEL_COMMUNI_CONST
  90. auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const();
  91. if (communi_const < 0) {
  92. MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero.";
  93. }
  94. COST_MODEL_COMMUNI_CONST = communi_const;
  95. MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << ".";
  96. // COST_MODEL_COMMUNI_BIAS
  97. auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
  98. if (communi_bias < 0) {
  99. MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero.";
  100. }
  101. COST_MODEL_COMMUNI_BIAS = communi_bias;
  102. MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << ".";
  103. // TENSOR_SLICE_ALIGNMENT_ENABLE
  104. auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
  105. TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable;
  106. if (TENSOR_SLICE_ALIGNMENT_ENABLE) {
  107. MS_LOG(INFO) << "tensor_slice_align_enable: true.";
  108. } else {
  109. MS_LOG(INFO) << "tensor_slice_align_enable: false.";
  110. }
  111. // TENSOR_SLICE_ALIGNMENT_SIZE
  112. auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
  113. if (align_size == 0) {
  114. MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
  115. }
  116. TENSOR_SLICE_ALIGNMENT_SIZE = align_size;
  117. MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << ".";
  118. // FULLY_USE_DEVICES
  119. auto fully_devices = CostModelContext::GetInstance()->fully_use_device();
  120. FULLY_USE_DEVICES = fully_devices;
  121. if (FULLY_USE_DEVICES) {
  122. MS_LOG(INFO) << "fully_use_devices: true.";
  123. } else {
  124. MS_LOG(INFO) << "fully_use_devices: false.";
  125. }
  126. // ELEMENTWISE_OP_STRA_FOLLOW
  127. auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
  128. ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow;
  129. if (ELEMENTWISE_OP_STRA_FOLLOW) {
  130. MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
  131. } else {
  132. MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
  133. }
  134. // MULTI_SUBGRAPHS
  135. auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
  136. MULTI_SUBGRAPHS = multi_subgraphs;
  137. if (MULTI_SUBGRAPHS) {
  138. MS_LOG(INFO) << "multi_subgraphs: true.";
  139. } else {
  140. MS_LOG(INFO) << "multi_subgraphs: false.";
  141. }
  142. // RUN_PHASE
  143. auto phase = CostModelContext::GetInstance()->run_phase();
  144. if (phase != 0 && phase != 1) {
  145. MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
  146. }
  147. RUN_PHASE = phase;
  148. MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
  149. }
  150. void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
  151. for (auto it = ops_.begin(); it != ops_.end();) {
  152. if ((*it) == op) {
  153. it = ops_.erase(it);
  154. } else {
  155. ++it;
  156. }
  157. }
  158. }
  159. bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
  160. struct IsInGraph {
  161. const OperatorInfoPtr test_;
  162. explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {}
  163. bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); }
  164. };
  165. return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
  166. }
  167. void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
  168. std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
  169. curr_edges.push_back(edge);
  170. edges_[{u_node, v_node}] = curr_edges;
  171. std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
  172. curr_out_edges.push_back(edge);
  173. out_edges_[u_node] = curr_out_edges;
  174. std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
  175. curr_in_edges.push_back(edge);
  176. in_edges_[v_node] = curr_in_edges;
  177. }
  178. bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
  179. for (auto &edge_pair : edges_) {
  180. auto edges = edge_pair.second;
  181. for (auto &edge : edges) {
  182. MS_EXCEPTION_IF_NULL(edge);
  183. bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) &&
  184. (edge->next_op_input_index() == input_index);
  185. if (bool_result) {
  186. return true;
  187. }
  188. }
  189. }
  190. return false;
  191. }
  192. std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
  193. std::vector<OperatorInfoPtr> alive_ops) {
  194. std::map<OperatorInfoPtr, bool> visited;
  195. for (auto &op : alive_ops) {
  196. visited[op] = false;
  197. }
  198. MS_LOG(INFO) << "visited: " << visited.size() << ".";
  199. for (auto &op : alive_ops) {
  200. if ((!visited[op]) && op->is_alive()) {
  201. std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
  202. MS_EXCEPTION_IF_NULL(new_component);
  203. new_component->SetDeviceMemoryAndCostParameter();
  204. DFS(op, &visited, new_component);
  205. connected_compoents_.push_back(new_component);
  206. }
  207. }
  208. return connected_compoents_;
  209. }
  210. void CostGraph::DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
  211. const std::shared_ptr<CostGraph> &component) {
  212. MS_EXCEPTION_IF_NULL(visited);
  213. MS_EXCEPTION_IF_NULL(component);
  214. visited->at(current_op) = true;
  215. component->AddOperator(current_op);
  216. for (auto &edge : current_op->succ_edges()) {
  217. bool bool_test = (visited->find(edge->next_operator()) != visited->end()) &&
  218. (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive();
  219. if (bool_test) {
  220. component->AddEdge(current_op, edge->next_operator(), edge);
  221. DFS(edge->next_operator(), visited, component);
  222. }
  223. }
  224. for (auto &edge : current_op->prev_edges()) {
  225. bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) &&
  226. (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive();
  227. if (bool_test) {
  228. component->AddEdge(edge->prev_operator(), current_op, edge);
  229. DFS(edge->prev_operator(), visited, component);
  230. }
  231. }
  232. }
  233. // Create final cost list for the graph: u --> v
  234. CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr<Edge> &e,
  235. const OperatorInfoPtr &v) {
  236. MS_EXCEPTION_IF_NULL(u);
  237. MS_EXCEPTION_IF_NULL(v);
  238. MS_EXCEPTION_IF_NULL(e);
  239. CostPtrList ret;
  240. for (const auto &u_strategy : u->GetStrategyCost()) {
  241. for (const auto &v_strategy : v->GetStrategyCost()) {
  242. MS_EXCEPTION_IF_NULL(u_strategy);
  243. MS_EXCEPTION_IF_NULL(v_strategy);
  244. auto u_strategy_ptr = u_strategy->strategy_ptr;
  245. auto v_strategy_ptr = v_strategy->strategy_ptr;
  246. CostPtrList clist1 = u_strategy->cost_list;
  247. CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr);
  248. CostPtrList clist3 = v_strategy->cost_list;
  249. for (const auto &cost1 : clist1) {
  250. for (const auto &cost2 : clist2) {
  251. for (const auto &cost3 : clist3) {
  252. MS_EXCEPTION_IF_NULL(cost1);
  253. MS_EXCEPTION_IF_NULL(cost2);
  254. MS_EXCEPTION_IF_NULL(cost3);
  255. double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
  256. double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
  257. double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
  258. double communication_forward =
  259. cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
  260. double communication_without_para = cost1->communication_without_parameter_ +
  261. cost2->communication_without_parameter_ +
  262. cost3->communication_without_parameter_;
  263. auto decision =
  264. std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
  265. auto cost = std::make_shared<Cost>(computation, communication, decision);
  266. MS_EXCEPTION_IF_NULL(cost);
  267. cost->communication_without_parameter_ = communication_without_para;
  268. cost->communication_with_partial_para_ =
  269. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  270. cost->memory_with_reuse_ = memory;
  271. cost->communication_forward_ = communication_forward;
  272. ret.push_back(cost);
  273. }
  274. }
  275. }
  276. }
  277. }
  278. Simplify(&ret);
  279. return ret;
  280. }
  281. // Create final cost list for the graph containing a signle node: u
  282. CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
  283. MS_EXCEPTION_IF_NULL(u);
  284. CostPtrList ret;
  285. for (const auto &u_strategy : u->GetStrategyCost()) {
  286. MS_EXCEPTION_IF_NULL(u_strategy);
  287. auto u_strategy_ptr = u_strategy->strategy_ptr;
  288. CostPtrList clist1 = u_strategy->cost_list;
  289. for (const auto &cost1 : clist1) {
  290. MS_EXCEPTION_IF_NULL(cost1);
  291. auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
  292. auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
  293. MS_EXCEPTION_IF_NULL(new_cost);
  294. new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
  295. new_cost->communication_with_partial_para_ =
  296. cost1->communication_without_parameter_ +
  297. COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
  298. new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
  299. new_cost->communication_forward_ = cost1->communication_forward_;
  300. ret.push_back(new_cost);
  301. }
  302. }
  303. Simplify(&ret);
  304. return ret;
  305. }
  306. CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
  307. // Select the cost with minimum inference time. Currently, the inference time is modeled as =
  308. // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
  309. if (cost_list.empty()) {
  310. MS_LOG(ERROR) << "Final cost list is null.";
  311. return nullptr;
  312. }
  313. CostPtrList after_mem_filter;
  314. double minimum_memory = DBL_MAX;
  315. // Filter out the valid costs.
  316. for (auto &a_cost : cost_list) {
  317. if (a_cost->memory_with_reuse_ <= memory) {
  318. after_mem_filter.emplace_back(std::move(a_cost));
  319. } else if (a_cost->memory_with_reuse_ < minimum_memory) {
  320. minimum_memory = a_cost->memory_with_reuse_;
  321. }
  322. }
  323. if (after_mem_filter.empty()) {
  324. MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
  325. << ", the memory capacity is: " << memory << ".";
  326. return nullptr;
  327. }
  328. // Init the returned value with first cost.
  329. CostPtr ret = after_mem_filter[0];
  330. double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
  331. MS_LOG(INFO) << "Cost 0: "
  332. << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
  333. << ", communication_forward_: " << ret->communication_forward_
  334. << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
  335. << ", communication_cost_: " << ret->communication_cost_
  336. << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
  337. MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
  338. for (size_t i = 1; i < after_mem_filter.size(); ++i) {
  339. MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
  340. MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
  341. << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
  342. << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
  343. << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
  344. << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
  345. << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
  346. << ".";
  347. auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
  348. costmodel_beta_ * after_mem_filter[i]->communication_forward_;
  349. MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
  350. if (minimum > tmp) {
  351. minimum = tmp;
  352. ret = after_mem_filter[i];
  353. MS_LOG(INFO) << "Selected: " << i;
  354. }
  355. }
  356. return ret;
  357. }
  358. CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
  359. // Select the cost with minimum training time. Currently, the training time is modeled as =
  360. // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_
  361. if (cost_list.empty()) {
  362. MS_LOG(ERROR) << "Final cost list is null.";
  363. return nullptr;
  364. }
  365. CostPtrList after_mem_filter;
  366. double minimum_memory = DBL_MAX;
  367. // Filter out the valid costs.
  368. for (auto &a_cost : cost_list) {
  369. if (a_cost->memory_with_reuse_ <= memory) {
  370. after_mem_filter.emplace_back(std::move(a_cost));
  371. } else if (a_cost->memory_with_reuse_ < minimum_memory) {
  372. minimum_memory = a_cost->memory_with_reuse_;
  373. }
  374. }
  375. if (after_mem_filter.empty()) {
  376. MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
  377. << ", the memory capacity is: " << memory << ".";
  378. return nullptr;
  379. }
  380. // Init the returned value with first cost.
  381. CostPtr ret = after_mem_filter[0];
  382. double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
  383. MS_LOG(INFO) << "Cost 0: "
  384. << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
  385. << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
  386. << ", communication_cost_: " << ret->communication_cost_
  387. << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
  388. MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
  389. for (size_t i = 1; i < after_mem_filter.size(); ++i) {
  390. MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
  391. MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
  392. << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
  393. << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
  394. << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
  395. << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
  396. << ".";
  397. auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
  398. costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
  399. MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
  400. if (minimum > tmp) {
  401. minimum = tmp;
  402. ret = after_mem_filter[i];
  403. MS_LOG(INFO) << "Selected: " << i;
  404. }
  405. }
  406. return ret;
  407. }
  408. CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_cost_list,
  409. double available_memory) {
  410. CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
  411. double minimum = DBL_MAX, total_memory = 0.0;
  412. CostPtrList ret(all_cost_list.size(), nullptr);
  413. // Check whether valid costs exist.
  414. for (size_t i = 0; i < all_cost_list.size(); ++i) {
  415. if (all_cost_list[i][0] == nullptr) {
  416. MS_LOG(ERROR) << "The cost list " << i << " is empty.";
  417. return ret;
  418. } else {
  419. double memory_i_cost = DBL_MAX;
  420. for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
  421. if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
  422. memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
  423. }
  424. }
  425. total_memory += memory_i_cost;
  426. }
  427. }
  428. if (total_memory >= available_memory) {
  429. MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory
  430. << ", minimum strategy cost: " << total_memory << ".";
  431. return selected_cost_list;
  432. }
  433. std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
  434. &available_memory, this](size_t k) {
  435. if (k == all_cost_list.size()) {
  436. double tmp_memory = 0.0, tmp_minimum = 0.0;
  437. for (size_t i = 0; i < selected_cost_list.size(); ++i) {
  438. MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
  439. tmp_memory += selected_cost_list[i]->memory_with_reuse_;
  440. tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
  441. costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
  442. }
  443. MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
  444. << ".";
  445. if (tmp_memory < available_memory && tmp_minimum < minimum) {
  446. ret = selected_cost_list;
  447. minimum = tmp_minimum;
  448. MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ".";
  449. }
  450. return;
  451. }
  452. MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << ".";
  453. for (auto &c : all_cost_list[k]) {
  454. selected_cost_list[k] = c;
  455. recursive(k + 1);
  456. }
  457. };
  458. recursive(0);
  459. return ret;
  460. }
  461. Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
  462. MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph.";
  463. auto connected_components = ConstructConnectedComponents(alive_ops);
  464. MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
  465. std::vector<CostPtrList> all_list;
  466. for (size_t j = 0; j < connected_components.size(); ++j) {
  467. auto one_component = connected_components[j];
  468. MS_EXCEPTION_IF_NULL(one_component);
  469. if (one_component->GetOperators().size() == 1) {
  470. MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
  471. auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
  472. all_list.push_back(cost_list);
  473. } else if (one_component->GetOperators().size() == 2) {
  474. MS_LOG(INFO) << "There are 2 operators in a component in the final graph.";
  475. OperatorInfoPtr u, v;
  476. auto first_op = one_component->GetOperators()[0];
  477. auto second_op = one_component->GetOperators()[1];
  478. MS_EXCEPTION_IF_NULL(first_op);
  479. MS_EXCEPTION_IF_NULL(second_op);
  480. if (!first_op->GetAliveSuccEdges().empty() &&
  481. first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
  482. u = first_op;
  483. v = second_op;
  484. } else if (!second_op->GetAliveSuccEdges().empty() &&
  485. second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
  486. u = second_op;
  487. v = first_op;
  488. } else {
  489. MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size()
  490. << ", " << second_op->GetAliveSuccEdges().size() << ".";
  491. }
  492. MS_EXCEPTION_IF_NULL(u);
  493. auto e = u->GetAliveSuccEdges()[0];
  494. auto cost_list = one_component->CreateFinalCostList(u, e, v);
  495. all_list.push_back(cost_list);
  496. } else {
  497. MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size()
  498. << " operators in a component in the final graph.";
  499. }
  500. }
  501. //
  502. auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
  503. for (size_t k = 0; k < selected_cost_list.size(); ++k) {
  504. auto selected_cost = selected_cost_list[k];
  505. if (selected_cost == nullptr) {
  506. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  507. return FAILED;
  508. }
  509. MS_EXCEPTION_IF_NULL(connected_components[k]);
  510. if (connected_components[k]->GetOperators().size() == 1) {
  511. auto u = connected_components[k]->GetOperators()[0];
  512. auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
  513. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
  514. MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
  515. } else if (connected_components[k]->GetOperators().size() == 2) {
  516. OperatorInfoPtr u = nullptr, v = nullptr;
  517. auto first_op = connected_components[k]->GetOperators()[0];
  518. auto second_op = connected_components[k]->GetOperators()[1];
  519. MS_EXCEPTION_IF_NULL(first_op);
  520. MS_EXCEPTION_IF_NULL(second_op);
  521. if (!first_op->GetAliveSuccEdges().empty() &&
  522. first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
  523. u = first_op;
  524. v = second_op;
  525. } else if (!second_op->GetAliveSuccEdges().empty() &&
  526. second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
  527. u = second_op;
  528. v = first_op;
  529. }
  530. MS_EXCEPTION_IF_NULL(u);
  531. auto e = u->GetAliveSuccEdges()[0];
  532. MS_EXCEPTION_IF_NULL(v);
  533. MS_EXCEPTION_IF_NULL(e);
  534. MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
  535. auto decision = selected_cost->decision_ptr_->cast<FinalDecisionPtr>();
  536. MS_EXCEPTION_IF_NULL(decision);
  537. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
  538. v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
  539. e->set_selected_cost(decision->middle_cost_);
  540. MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
  541. }
  542. }
  543. return SUCCESS;
  544. }
  545. // searching the strategy for the final eliminated graph
  546. Status CostGraph::SearchStrategy() {
  547. MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
  548. std::vector<OperatorInfoPtr> alive_ops;
  549. (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) {
  550. MS_EXCEPTION_IF_NULL(op);
  551. if (op->is_alive()) {
  552. alive_ops.push_back(op);
  553. }
  554. });
  555. if (alive_ops.size() > 2) {
  556. if (RUN_PHASE == TRAINING_PHASE) {
  557. // training phase
  558. return SearchStrategyForMultiNodeFinalGraph(alive_ops);
  559. } else {
  560. // inference phase
  561. MS_LOG(EXCEPTION)
  562. << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
  563. }
  564. } else if (alive_ops.size() == 1) {
  565. MS_LOG(INFO) << "There are 1 single node in the final graph.";
  566. OperatorInfoPtr u = alive_ops[0];
  567. auto cost_list = CreateFinalSingleCostList(u);
  568. CostPtr cost = nullptr;
  569. if (RUN_PHASE == TRAINING_PHASE) {
  570. // training phase
  571. cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
  572. } else {
  573. // inference phase
  574. cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
  575. }
  576. if (cost == nullptr) {
  577. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  578. return FAILED;
  579. }
  580. MS_EXCEPTION_IF_NULL(u);
  581. MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
  582. auto decision = cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
  583. MS_EXCEPTION_IF_NULL(decision);
  584. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
  585. MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
  586. return SUCCESS;
  587. } else {
  588. // In this case, the final graph should contains exactly 2 nodes.
  589. if (alive_ops.empty()) {
  590. MS_LOG(INFO) << "0 Operator in the final graph.";
  591. return SUCCESS;
  592. }
  593. OperatorInfoPtr u, v;
  594. MS_EXCEPTION_IF_NULL(alive_ops[0]);
  595. MS_EXCEPTION_IF_NULL(alive_ops[1]);
  596. if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
  597. alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
  598. u = alive_ops[0];
  599. v = alive_ops[1];
  600. } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
  601. alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
  602. u = alive_ops[1];
  603. v = alive_ops[0];
  604. } else {
  605. if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
  606. MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
  607. << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
  608. } else {
  609. // In this case, the final graph consists of two single nodes
  610. MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
  611. std::vector<CostPtrList> all_list;
  612. auto connected_components = ConstructConnectedComponents(alive_ops);
  613. MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
  614. for (size_t i = 0; i < connected_components.size(); ++i) {
  615. MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
  616. auto one_component = connected_components[i];
  617. MS_EXCEPTION_IF_NULL(one_component);
  618. auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
  619. all_list.push_back(cost_list);
  620. }
  621. CostPtrList selected_cost_list;
  622. if (RUN_PHASE == TRAINING_PHASE) {
  623. // training phase
  624. selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
  625. } else {
  626. // inference phase
  627. MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
  628. "phase is not supported.";
  629. }
  630. for (size_t k = 0; k < selected_cost_list.size(); ++k) {
  631. auto selected_cost = selected_cost_list[k];
  632. if (selected_cost == nullptr) {
  633. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  634. return FAILED;
  635. }
  636. MS_EXCEPTION_IF_NULL(connected_components[k]);
  637. auto one_operator = connected_components[k]->GetOperators()[0];
  638. MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
  639. auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
  640. MS_EXCEPTION_IF_NULL(decision);
  641. one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
  642. MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
  643. }
  644. return SUCCESS;
  645. }
  646. }
  647. MS_LOG(INFO) << "There are 2 nodes in the final graph.";
  648. // In this case, the finale graph is exactly of the form: u --> v
  649. MS_EXCEPTION_IF_NULL(u);
  650. MS_EXCEPTION_IF_NULL(v);
  651. auto e = u->GetAliveSuccEdges()[0];
  652. MS_EXCEPTION_IF_NULL(e);
  653. auto cost_list = CreateFinalCostList(u, e, v);
  654. CostPtr cost = nullptr;
  655. if (RUN_PHASE == TRAINING_PHASE) {
  656. // training phase
  657. cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
  658. } else {
  659. MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
  660. "phase is not supported.";
  661. }
  662. if (cost == nullptr) {
  663. MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
  664. return FAILED;
  665. }
  666. MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
  667. auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
  668. MS_EXCEPTION_IF_NULL(decision);
  669. u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
  670. v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
  671. e->set_selected_cost(decision->middle_cost_);
  672. MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
  673. return SUCCESS;
  674. }
  675. }
  676. // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
  677. // return the v and the edge u --> v
  678. OperatorInfoPtr CostGraph::CheckOpElimination() const {
  679. for (auto &op : ops_) {
  680. bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1;
  681. if (bool_test) {
  682. if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) {
  683. return op;
  684. }
  685. }
  686. }
  687. return nullptr;
  688. }
  689. // Check the graph whether an EdgeElimination can be performed
  690. std::vector<std::shared_ptr<Edge>> CostGraph::CheckEdgeElimination() const {
  691. for (auto &op : ops_) {
  692. MS_EXCEPTION_IF_NULL(op);
  693. if (!op->is_alive()) continue;
  694. std::map<void *, int> count;
  695. for (auto &edge : op->GetAliveSuccEdges()) {
  696. MS_EXCEPTION_IF_NULL(edge);
  697. auto v = edge->next_operator();
  698. count[v.get()]++;
  699. }
  700. for (auto &pair : count) {
  701. auto *op_ptr = pair.first;
  702. int op_count = pair.second;
  703. if (op_count > 1) {
  704. std::vector<std::shared_ptr<Edge>> ret;
  705. for (auto &edge : op->GetAliveSuccEdges()) {
  706. MS_EXCEPTION_IF_NULL(edge);
  707. if (edge->next_operator().get() == op_ptr) {
  708. ret.push_back(edge);
  709. }
  710. }
  711. return ret;
  712. }
  713. }
  714. }
  715. return {};
  716. }
  717. // Check the graph whether a MergeElimination can be performed
  718. OperatorInfoPtr CostGraph::CheckMergeElimination() const {
  719. for (auto &op : ops_) {
  720. MS_EXCEPTION_IF_NULL(op);
  721. bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1;
  722. if (bool_test) {
  723. auto next_op = op->GetAliveSuccEdges()[0]->next_operator();
  724. MS_EXCEPTION_IF_NULL(next_op);
  725. if (!next_op->GetAlivePrevEdges().empty()) {
  726. return op;
  727. }
  728. }
  729. }
  730. return nullptr;
  731. }
  732. // Check the graph whether a ContractElimination can be performed
  733. OperatorInfoPtr CostGraph::CheckContractElimination() const {
  734. for (auto &op : ops_) {
  735. MS_EXCEPTION_IF_NULL(op);
  736. bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty();
  737. if (bool_test) {
  738. auto edge = op->GetAlivePrevEdges()[0];
  739. MS_EXCEPTION_IF_NULL(edge);
  740. auto prev_op = edge->prev_operator();
  741. MS_EXCEPTION_IF_NULL(prev_op);
  742. if (!prev_op->GetAliveSuccEdges().empty()) {
  743. return op;
  744. }
  745. }
  746. }
  747. return nullptr;
  748. }
  749. // Check the graph whether a TriangleElimination can be performed
  750. std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
  751. for (auto &op : ops_) {
  752. MS_EXCEPTION_IF_NULL(op);
  753. bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2);
  754. if (bool_test) {
  755. auto edge1 = op->GetAliveSuccEdges()[0];
  756. auto edge2 = op->GetAliveSuccEdges()[1];
  757. MS_EXCEPTION_IF_NULL(edge1);
  758. MS_EXCEPTION_IF_NULL(edge2);
  759. auto first_op = edge1->next_operator();
  760. auto second_op = edge2->next_operator();
  761. MS_EXCEPTION_IF_NULL(first_op);
  762. for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) {
  763. if (first_op_succ_edge->next_operator() == second_op) {
  764. return {op, first_op_succ_edge};
  765. }
  766. }
  767. MS_EXCEPTION_IF_NULL(second_op);
  768. for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) {
  769. if (second_op_succ_edge->next_operator() == first_op) {
  770. return {op, second_op_succ_edge};
  771. }
  772. }
  773. }
  774. }
  775. return {nullptr, nullptr};
  776. }
  777. // Check the graph whether a StarElimination can be performed.
  778. // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
  779. OperatorInfoPtr CostGraph::CheckStarElimination() const {
  780. for (auto &op : ops_) {
  781. MS_EXCEPTION_IF_NULL(op);
  782. bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1);
  783. if (bool_test) {
  784. return op;
  785. }
  786. }
  787. return nullptr;
  788. }
  789. // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace
  790. // 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge.
  791. std::shared_ptr<Edge> CostGraph::EliminationOp(const OperatorInfoPtr &op) {
  792. // in this case, the operators are organised in the form of u-->op-->v, and the goal
  793. // is to eliminate 'op'.
  794. MS_EXCEPTION_IF_NULL(op);
  795. MS_LOG(INFO) << "Now eliminating node: " << op->name() << ".";
  796. auto edge_u_op = op->GetAlivePrevEdges()[0];
  797. auto edge_op_v = op->GetAliveSuccEdges()[0];
  798. MS_EXCEPTION_IF_NULL(edge_u_op);
  799. MS_EXCEPTION_IF_NULL(edge_op_v);
  800. auto u = edge_u_op->prev_operator();
  801. auto v = edge_op_v->next_operator();
  802. std::vector<size_t> output_indexs, input_indexs;
  803. size_t output_index, input_index;
  804. MS_EXCEPTION_IF_NULL(u);
  805. MS_EXCEPTION_IF_NULL(v);
  806. std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
  807. std::shared_ptr<Edge> new_edge;
  808. if (edge_u_op->is_combined()) {
  809. output_indexs = edge_u_op->prev_op_output_indexs();
  810. } else {
  811. output_index = edge_u_op->prev_op_output_index();
  812. output_indexs.push_back(output_index);
  813. }
  814. if (edge_op_v->is_combined()) {
  815. input_indexs = edge_op_v->next_op_input_indexs();
  816. } else {
  817. input_index = edge_op_v->next_op_input_index();
  818. input_indexs.push_back(input_index);
  819. }
  820. if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) {
  821. new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_index, input_index, false);
  822. } else {
  823. new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
  824. }
  825. MS_EXCEPTION_IF_NULL(new_edge);
  826. new_edge->set_pre_op_output(edge_u_op->prev_op_output());
  827. new_edge->set_next_op_input(edge_op_v->next_op_input());
  828. new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v);
  829. u->ReplaceSuccEdge(op, new_edge);
  830. v->ReplacePreEdge(op, new_edge);
  831. op->SetNotAlive();
  832. MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded.";
  833. return new_edge;
  834. }
  835. // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges',
  836. // and sets new costlist for the new edge.
  837. std::shared_ptr<Edge> CostGraph::EliminationEdges(const std::vector<std::shared_ptr<Edge>> &edges) {
  838. MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges.";
  839. MS_EXCEPTION_IF_NULL(edges[0]);
  840. auto u = edges[0]->prev_operator();
  841. auto v = edges[0]->next_operator();
  842. MS_EXCEPTION_IF_NULL(u);
  843. MS_EXCEPTION_IF_NULL(v);
  844. std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
  845. std::vector<size_t> output_indexs, input_indexs;
  846. for (auto &edge : edges) {
  847. MS_EXCEPTION_IF_NULL(edge);
  848. if (edge->is_combined()) {
  849. auto from_output_indexs = edge->prev_op_output_indexs();
  850. auto from_input_indexs = edge->next_op_input_indexs();
  851. (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs));
  852. (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs));
  853. } else {
  854. output_indexs.push_back(edge->prev_op_output_index());
  855. input_indexs.push_back(edge->next_op_input_index());
  856. }
  857. }
  858. std::shared_ptr<Edge> new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
  859. MS_EXCEPTION_IF_NULL(new_edge);
  860. new_edge->set_pre_op_output(edges[0]->prev_op_output());
  861. new_edge->set_next_op_input(edges[0]->next_op_input());
  862. new_edge->EdgeEliminationSetNewCost(u, edges, v);
  863. u->ReplaceSuccEdges(v, new_edge);
  864. v->ReplacePreEdges(u, new_edge);
  865. MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded.";
  866. return new_edge;
  867. }
  868. // Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
  869. // for this contract under the strategy 'op_strategy'
  870. void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
  871. const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
  872. const CostPtrList &tar_cost_list,
  873. CostPtrList *const tar_cost_list_new) {
  874. for (size_t i = 0; i < op_cost_list.size(); ++i) {
  875. auto &op_cost = op_cost_list[i];
  876. MS_EXCEPTION_IF_NULL(op_cost);
  877. for (size_t j = 0; j < edge_cost_list.size(); ++j) {
  878. auto &edge_cost = edge_cost_list[j];
  879. MS_EXCEPTION_IF_NULL(edge_cost);
  880. for (size_t k = 0; k < tar_cost_list.size(); ++k) {
  881. auto &tar_cost = tar_cost_list[k];
  882. MS_EXCEPTION_IF_NULL(tar_cost);
  883. double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
  884. double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
  885. double communication =
  886. op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
  887. double communication_forward =
  888. op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
  889. double communication_without_para = op_cost->communication_without_parameter_ +
  890. edge_cost->communication_without_parameter_ +
  891. tar_cost->communication_without_parameter_;
  892. auto decision =
  893. std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
  894. auto new_cost = std::make_shared<Cost>(computation, communication, decision);
  895. MS_EXCEPTION_IF_NULL(new_cost);
  896. new_cost->communication_without_parameter_ = communication_without_para;
  897. new_cost->communication_with_partial_para_ =
  898. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  899. new_cost->memory_with_reuse_ = memory;
  900. new_cost->communication_forward_ = communication_forward;
  901. MS_EXCEPTION_IF_NULL(tar_cost_list_new);
  902. tar_cost_list_new->emplace_back(std::move(new_cost));
  903. }
  904. }
  905. }
  906. }
  907. // This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the
  908. // target_op
  909. OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {
  910. MS_EXCEPTION_IF_NULL(op);
  911. auto target_op = op->GetAliveSuccEdges()[0]->next_operator();
  912. auto edge_ptr = op->GetAliveSuccEdges()[0];
  913. MS_EXCEPTION_IF_NULL(target_op);
  914. MS_EXCEPTION_IF_NULL(edge_ptr);
  915. MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << ".";
  916. bool valid = false;
  917. for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
  918. MS_EXCEPTION_IF_NULL(tar_stra_cost);
  919. auto tar_stra = tar_stra_cost->strategy_ptr;
  920. auto tar_clist_origin = tar_stra_cost->cost_list;
  921. CostPtrList tar_clist_new;
  922. for (auto &op_stra_cost : op->GetStrategyCost()) {
  923. MS_EXCEPTION_IF_NULL(op_stra_cost);
  924. auto op_stra = op_stra_cost->strategy_ptr;
  925. auto op_clist = op_stra_cost->cost_list;
  926. auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra);
  927. CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
  928. }
  929. Simplify(&tar_clist_new);
  930. // Set the new costlist w.r.t the strategy
  931. tar_stra_cost->cost_list = tar_clist_new;
  932. if ((!valid) && (!tar_clist_new.empty())) {
  933. valid = true;
  934. }
  935. }
  936. if (!valid) {
  937. MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed.";
  938. }
  939. op->SetNotAlive();
  940. MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded.";
  941. return target_op;
  942. }
  943. // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
  944. // for this contract under the strategy 'contract_op_stra'
  945. void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,
  946. const CostPtrList &contract_op_cost_list,
  947. const CostPtrList &edge_cost_list, StrategyPtr target_op_stra,
  948. const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) {
  949. for (size_t i = 0; i < contract_op_cost_list.size(); ++i) {
  950. auto &contract_op_cost = contract_op_cost_list[i];
  951. MS_EXCEPTION_IF_NULL(contract_op_cost);
  952. for (size_t j = 0; j < edge_cost_list.size(); ++j) {
  953. auto &edge_cost = edge_cost_list[j];
  954. MS_EXCEPTION_IF_NULL(edge_cost);
  955. for (size_t k = 0; k < tar_cost_list.size(); ++k) {
  956. auto &tar_cost = tar_cost_list[k];
  957. MS_EXCEPTION_IF_NULL(tar_cost);
  958. double computation =
  959. contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
  960. double memory =
  961. contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
  962. double communication =
  963. contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
  964. double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
  965. tar_cost->communication_forward_;
  966. double communication_without_para = contract_op_cost->communication_without_parameter_ +
  967. edge_cost->communication_without_parameter_ +
  968. tar_cost->communication_without_parameter_;
  969. auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
  970. target_op_stra, tar_cost);
  971. auto new_cost = std::make_shared<Cost>(computation, communication, decision);
  972. new_cost->communication_without_parameter_ = communication_without_para;
  973. new_cost->communication_with_partial_para_ =
  974. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  975. new_cost->memory_with_reuse_ = memory;
  976. new_cost->communication_forward_ = communication_forward;
  977. tar_cost_list_new->emplace_back(std::move(new_cost));
  978. }
  979. }
  980. }
  981. }
  982. // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the
  983. // target_op
  984. OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {
  985. MS_EXCEPTION_IF_NULL(op);
  986. auto target_op = op->GetAlivePrevEdges()[0]->prev_operator();
  987. auto edge_ptr = op->GetAlivePrevEdges()[0];
  988. MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << ".";
  989. bool valid = false;
  990. for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
  991. MS_EXCEPTION_IF_NULL(tar_stra_cost);
  992. auto tar_stra = tar_stra_cost->strategy_ptr;
  993. auto tar_clist_origin = tar_stra_cost->cost_list;
  994. CostPtrList tar_clist_new;
  995. for (auto &op_stra_cost : op->GetStrategyCost()) {
  996. MS_EXCEPTION_IF_NULL(op_stra_cost);
  997. auto op_stra = op_stra_cost->strategy_ptr;
  998. auto op_clist = op_stra_cost->cost_list;
  999. auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra);
  1000. CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
  1001. }
  1002. Simplify(&tar_clist_new);
  1003. // Set the new costlist w.r.t the strategy
  1004. tar_stra_cost->cost_list = tar_clist_new;
  1005. if ((!valid) && (!tar_clist_new.empty())) {
  1006. valid = true;
  1007. }
  1008. }
  1009. if (!valid) {
  1010. MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed.";
  1011. }
  1012. op->SetNotAlive();
  1013. MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded.";
  1014. return target_op;
  1015. }
  1016. void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
  1017. StrategyPtr right_op_stra, const CostPtr &right_op_cost,
  1018. const CostPtrList &elimi_op_clist,
  1019. const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost,
  1020. const CostPtrList &left_node_clist_origin,
  1021. CostPtrList *left_node_clist_new) {
  1022. MS_EXCEPTION_IF_NULL(right_edge_cost);
  1023. MS_EXCEPTION_IF_NULL(right_op_cost);
  1024. MS_EXCEPTION_IF_NULL(left_node_clist_new);
  1025. for (auto &elimi_op_cost : elimi_op_clist) {
  1026. MS_EXCEPTION_IF_NULL(elimi_op_cost);
  1027. for (auto &left_edge_cost : left_edge_clist) {
  1028. MS_EXCEPTION_IF_NULL(left_edge_cost);
  1029. for (auto &left_node_cost : left_node_clist_origin) {
  1030. MS_EXCEPTION_IF_NULL(left_node_cost);
  1031. double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
  1032. left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
  1033. double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
  1034. left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
  1035. double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
  1036. left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
  1037. double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
  1038. left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
  1039. double new_commu_without =
  1040. elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
  1041. left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
  1042. auto decision = std::make_shared<TriangleEliminationDecision>(
  1043. elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra);
  1044. auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
  1045. new_cost->communication_without_parameter_ = new_commu_without;
  1046. new_cost->communication_with_partial_para_ =
  1047. new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
  1048. new_cost->memory_with_reuse_ = new_memory;
  1049. new_cost->communication_forward_ = new_commu_forward;
  1050. left_node_clist_new->emplace_back(std::move(new_cost));
  1051. }
  1052. }
  1053. }
  1054. }
  1055. void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist,
  1056. const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra,
  1057. const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra,
  1058. const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist,
  1059. const CostPtrList &left_node_clist_origin,
  1060. CostPtrList *left_node_clist_new) {
  1061. MS_EXCEPTION_IF_NULL(elimi_op);
  1062. for (auto &right_node_cost : right_node_clist) {
  1063. MS_EXCEPTION_IF_NULL(right_node_cost);
  1064. for (auto &right_edge_cost : right_edge_clist) {
  1065. MS_EXCEPTION_IF_NULL(right_edge_cost);
  1066. CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
  1067. elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
  1068. left_node_clist_new);
  1069. }
  1070. }
  1071. }
  1072. OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
  1073. const std::shared_ptr<Edge> &edge_left_right) {
  1074. MS_EXCEPTION_IF_NULL(edge_left_right);
  1075. MS_EXCEPTION_IF_NULL(elimi_op);
  1076. MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << ".";
  1077. auto left_node = edge_left_right->prev_operator();
  1078. auto right_node = edge_left_right->next_operator();
  1079. auto left_edge = elimi_op->GetAliveSuccEdges()[0];
  1080. auto right_edge = elimi_op->GetAliveSuccEdges()[1];
  1081. MS_EXCEPTION_IF_NULL(left_node);
  1082. MS_EXCEPTION_IF_NULL(right_node);
  1083. MS_EXCEPTION_IF_NULL(left_edge);
  1084. MS_EXCEPTION_IF_NULL(right_edge);
  1085. MS_LOG(INFO) << "The left operator is: " << left_node->name() << ".";
  1086. MS_LOG(INFO) << "The right operator is: " << right_node->name() << ".";
  1087. if (left_edge->next_operator() != left_node) {
  1088. auto tmp = left_edge;
  1089. left_edge = right_edge;
  1090. right_edge = tmp;
  1091. }
  1092. bool valid = false;
  1093. for (auto &left_node_stra_cost : left_node->GetStrategyCost()) {
  1094. MS_EXCEPTION_IF_NULL(left_node_stra_cost);
  1095. auto left_node_stra = left_node_stra_cost->strategy_ptr;
  1096. auto left_node_clist_origin = left_node_stra_cost->cost_list;
  1097. CostPtrList left_node_clist_new;
  1098. for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) {
  1099. MS_EXCEPTION_IF_NULL(elimi_op_stra_cost);
  1100. auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr;
  1101. auto elimi_op_clist = elimi_op_stra_cost->cost_list;
  1102. auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra);
  1103. for (auto &right_node_stra_cost : right_node->GetStrategyCost()) {
  1104. MS_EXCEPTION_IF_NULL(right_node_stra_cost);
  1105. auto right_node_stra = right_node_stra_cost->strategy_ptr;
  1106. auto right_node_clist = right_node_stra_cost->cost_list;
  1107. auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra);
  1108. CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra,
  1109. right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin,
  1110. &left_node_clist_new);
  1111. }
  1112. }
  1113. Simplify(&left_node_clist_new);
  1114. // Set the new costlist w.r.t the strategy
  1115. left_node_stra_cost->cost_list = left_node_clist_new;
  1116. if ((!valid) && (!left_node_clist_new.empty())) {
  1117. valid = true;
  1118. }
  1119. }
  1120. if (!valid) {
  1121. MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed.";
  1122. }
  1123. elimi_op->SetNotAlive();
  1124. MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded.";
  1125. return left_node;
  1126. }
  1127. void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra,
  1128. const CostPtrList &first_succ_node_clist,
  1129. const CostPtrList &first_succ_edge_clist,
  1130. const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
  1131. std::vector<StrategyPtr> succ_nodes_stras,
  1132. CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs,
  1133. CostPtrList *first_succ_node_clist_new) {
  1134. for (auto &first_succ_node_cost : first_succ_node_clist) {
  1135. for (auto &first_succ_edge_cost : first_succ_edge_clist) {
  1136. for (auto &merged_node_cost : merged_op_clist) {
  1137. MS_EXCEPTION_IF_NULL(merged_node_cost);
  1138. succ_nodes_stras[0] = first_succ_node_stra;
  1139. succ_edges_costs[0] = first_succ_edge_cost;
  1140. succ_nodes_costs[0] = first_succ_node_cost;
  1141. double computation_cost = merged_node_cost->computation_cost_,
  1142. memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
  1143. commu_without = merged_node_cost->communication_without_parameter_,
  1144. commu_forward = merged_node_cost->communication_forward_;
  1145. for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
  1146. MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
  1147. if (i == 0) {
  1148. computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
  1149. memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
  1150. commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
  1151. commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
  1152. commu_without += succ_edges_costs[i]->communication_without_parameter_ +
  1153. succ_nodes_costs[i]->communication_without_parameter_;
  1154. } else {
  1155. computation_cost += succ_edges_costs[i]->computation_cost_;
  1156. memory_cost += succ_edges_costs[i]->memory_with_reuse_;
  1157. commu_cost += succ_edges_costs[i]->communication_cost_;
  1158. commu_forward += succ_edges_costs[i]->communication_forward_;
  1159. commu_without += succ_edges_costs[i]->communication_without_parameter_;
  1160. }
  1161. }
  1162. auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
  1163. succ_nodes_stras, succ_nodes_costs);
  1164. auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
  1165. new_cost->communication_without_parameter_ = commu_without;
  1166. new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
  1167. new_cost->memory_with_reuse_ = memory_cost;
  1168. new_cost->communication_forward_ = commu_forward;
  1169. first_succ_node_clist_new->emplace_back(std::move(new_cost));
  1170. }
  1171. }
  1172. }
  1173. }
  1174. void CostGraph::CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> &succ_edges,
  1175. const StrategyPtr &first_succ_node_stra,
  1176. const CostPtrList &first_succ_node_clist,
  1177. const CostPtrList &first_succ_edge_clist,
  1178. const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
  1179. CostPtrList *first_succ_node_clist_new) {
  1180. std::vector<StrategyPtr> succ_nodes_stras(succ_edges.size(), nullptr);
  1181. CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr);
  1182. std::function<void(size_t)> recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist,
  1183. &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs,
  1184. &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive,
  1185. this](size_t k) {
  1186. if (k == succ_edges.size()) {
  1187. CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
  1188. merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs,
  1189. succ_nodes_costs, first_succ_node_clist_new);
  1190. return;
  1191. }
  1192. MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size()
  1193. << ", first_succ_edge_clist: " << first_succ_edge_clist.size()
  1194. << ", merged_op_clist: " << merged_op_clist.size()
  1195. << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << ".";
  1196. auto succ_edge = succ_edges[k];
  1197. MS_EXCEPTION_IF_NULL(succ_edge);
  1198. auto succ_node = succ_edge->next_operator();
  1199. MS_EXCEPTION_IF_NULL(succ_node);
  1200. for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) {
  1201. MS_EXCEPTION_IF_NULL(succ_node_stra_cost);
  1202. auto succ_node_stra = succ_node_stra_cost->strategy_ptr;
  1203. auto succ_node_clist = succ_node_stra_cost->cost_list;
  1204. auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra);
  1205. for (auto &succ_node_cost : succ_node_clist) {
  1206. MS_EXCEPTION_IF_NULL(succ_node_cost);
  1207. for (auto &succ_edge_cost : succ_edge_clist) {
  1208. MS_EXCEPTION_IF_NULL(succ_edge_cost);
  1209. succ_nodes_stras[k] = succ_node_stra;
  1210. succ_edges_costs[k] = succ_edge_cost;
  1211. succ_nodes_costs[k] = succ_node_cost;
  1212. recursive(k + 1);
  1213. }
  1214. }
  1215. }
  1216. };
  1217. recursive(1);
  1218. }
  1219. std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) {
  1220. MS_EXCEPTION_IF_NULL(merged_op);
  1221. auto succ_edges = merged_op->GetAliveSuccEdges();
  1222. MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << ".";
  1223. for (auto &succ_edge : succ_edges) {
  1224. MS_EXCEPTION_IF_NULL(succ_edge->next_operator());
  1225. MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << ".";
  1226. }
  1227. MS_EXCEPTION_IF_NULL(succ_edges[0]);
  1228. auto first_succ_node = succ_edges[0]->next_operator();
  1229. auto first_succ_edge = succ_edges[0];
  1230. bool valid = false;
  1231. // 'merged_op' is merged into first_node
  1232. MS_EXCEPTION_IF_NULL(first_succ_node);
  1233. for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) {
  1234. MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost);
  1235. auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr;
  1236. auto first_succ_node_clist = first_succ_node_stra_cost->cost_list;
  1237. CostPtrList first_succ_node_clist_new;
  1238. for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) {
  1239. MS_EXCEPTION_IF_NULL(merged_op_stra_cost);
  1240. auto merged_op_stra = merged_op_stra_cost->strategy_ptr;
  1241. auto merged_op_clist = merged_op_stra_cost->cost_list;
  1242. auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra);
  1243. CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
  1244. merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
  1245. }
  1246. Simplify(&first_succ_node_clist_new);
  1247. // Set the new costlist w.r.t the strategy
  1248. first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
  1249. if ((!valid) && (!first_succ_node_clist_new.empty())) {
  1250. valid = true;
  1251. }
  1252. }
  1253. if (!valid) {
  1254. MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed.";
  1255. }
  1256. merged_op->SetNotAlive();
  1257. MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded.";
  1258. return succ_edges;
  1259. }
  1260. size_t CostGraph::GetNumEdges() const {
  1261. size_t sum = 0;
  1262. for (const auto &kv : edges_) {
  1263. auto &edges = kv.second;
  1264. sum += edges.size();
  1265. }
  1266. return sum;
  1267. }
  1268. Status CostGraph::InitSelectedStrategy() {
  1269. for (auto &op : ops_) {
  1270. MS_EXCEPTION_IF_NULL(op);
  1271. if (op->name().find(RESHAPEINFO) != std::string::npos) {
  1272. continue;
  1273. }
  1274. auto result = op->InitSelectedStrategy(op->selected_strategy());
  1275. if (result != SUCCESS) {
  1276. return result;
  1277. }
  1278. }
  1279. // reshape init should be apply after the init of it's previous node and next node.
  1280. for (size_t i = 0; i < ops_.size(); ++i) {
  1281. if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
  1282. auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
  1283. auto in_edges = GetOriginalPrevEdges(ops_[i]);
  1284. auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr<Edge> edge) {
  1285. return edge->prev_operator()->name() == reshape_info->pre_operator_name();
  1286. });
  1287. auto out_edges = GetOriginalNextEdges(ops_[i]);
  1288. auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
  1289. return edge->next_operator()->name() == reshape_info->next_operator_name();
  1290. });
  1291. if (pre_iter != in_edges.end()) {
  1292. MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
  1293. int32_t pre_index = reshape_info->pre_operator_index();
  1294. TensorInfo pre_info;
  1295. if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
  1296. pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
  1297. } else {
  1298. pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
  1299. }
  1300. reshape_info->SetInputLayout(pre_info.tensor_layout());
  1301. Dimensions stra = pre_info.InferStrategy();
  1302. if (stra.empty()) {
  1303. MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
  1304. }
  1305. std::vector<Dimensions> stra_inputs = {stra};
  1306. StrategyPtr reshape_stra =
  1307. std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
  1308. reshape_info->set_strategy(reshape_stra);
  1309. }
  1310. if (next_iter != out_edges.end()) {
  1311. MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
  1312. int32_t next_index = reshape_info->next_operator_index();
  1313. reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
  1314. }
  1315. if (reshape_info->Init(nullptr) != SUCCESS) {
  1316. return FAILED;
  1317. }
  1318. }
  1319. }
  1320. return SUCCESS;
  1321. }
  1322. Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
  1323. for (auto &op : ops_) {
  1324. MS_EXCEPTION_IF_NULL(op);
  1325. const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved();
  1326. if ((output_parameter != 0) && (output_parameter != 1)) {
  1327. MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed.";
  1328. return FAILED;
  1329. }
  1330. }
  1331. return SUCCESS;
  1332. }
  1333. void CostGraph::DFSForTopoOrder(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
  1334. std::vector<OperatorInfoPtr> *topo_order) {
  1335. MS_EXCEPTION_IF_NULL(current_op);
  1336. MS_EXCEPTION_IF_NULL(visited);
  1337. MS_EXCEPTION_IF_NULL(topo_order);
  1338. visited->at(current_op) = true;
  1339. for (const auto &s_edge : current_op->succ_edges()) {
  1340. if (!visited->at(s_edge->next_operator())) {
  1341. DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
  1342. }
  1343. }
  1344. topo_order->push_back(current_op);
  1345. }
  1346. // Compute a topological order of the costgraph
  1347. void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
  1348. std::map<OperatorInfoPtr, bool> visited;
  1349. for (auto &op : ops_) {
  1350. visited[op] = false;
  1351. }
  1352. for (auto &op : ops_) {
  1353. if (!visited[op]) {
  1354. DFSForTopoOrder(op, &visited, topo_order);
  1355. }
  1356. }
  1357. }
  1358. void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) {
  1359. for (auto &op : ops_) {
  1360. auto search = candidate_ops.find(op);
  1361. if (search != candidate_ops.end()) {
  1362. // Mark the critical operators
  1363. op->mark_output_critical();
  1364. // Mark the successive edges
  1365. for (auto &s_edge : op->succ_edges()) {
  1366. s_edge->mark_output_critical();
  1367. }
  1368. } else {
  1369. op->mark_output_not_critical();
  1370. }
  1371. }
  1372. }
  1373. Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
  1374. if (topo_order.size() == 0) {
  1375. MS_LOG(ERROR) << "0 operator in costgraph.";
  1376. return FAILED;
  1377. }
  1378. auto &first_op = topo_order[0];
  1379. if (first_op->prev_edges().size() > 0) {
  1380. MS_LOG(ERROR) << "The first operator in the first of topological order of "
  1381. "costgraph should have 0 incoming edge, but has "
  1382. << first_op->prev_edges() << "edges.";
  1383. return FAILED;
  1384. }
  1385. // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
  1386. // of the output of OperatorInfo that currently has not been used
  1387. std::map<OperatorInfoPtr, int> curr_memory_state;
  1388. (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size())));
  1389. std::map<OperatorInfoPtr, int> max_memory_state = curr_memory_state;
  1390. // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
  1391. // not been used
  1392. double curr_memory_size = first_op->GetOutputsTotalSize();
  1393. double max_memory_size = curr_memory_size;
  1394. for (size_t finished = 1; finished < topo_order.size(); ++finished) {
  1395. // Produce
  1396. (void)curr_memory_state.emplace(
  1397. std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size())));
  1398. curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
  1399. // Consume
  1400. for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
  1401. const auto &prev_op = prev_edge->prev_operator();
  1402. curr_memory_state[prev_op]--;
  1403. }
  1404. for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
  1405. const auto &prev_op = prev_edge->prev_operator();
  1406. if (curr_memory_state[prev_op] < 0) {
  1407. MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
  1408. return FAILED;
  1409. } else if (curr_memory_state[prev_op] == 0) {
  1410. curr_memory_state.erase(prev_op);
  1411. curr_memory_size -= prev_op->GetOutputsTotalSize();
  1412. }
  1413. }
  1414. if (curr_memory_size < 0) {
  1415. MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
  1416. }
  1417. // Modify the max
  1418. if (curr_memory_size > max_memory_size) {
  1419. max_memory_size = curr_memory_size;
  1420. max_memory_state = curr_memory_state;
  1421. }
  1422. }
  1423. // Mark those critical operators
  1424. MarkCriticalOpsAndEdges(max_memory_state);
  1425. return SUCCESS;
  1426. }
  1427. Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
  1428. // Two steps to do:
  1429. // 1. Compute a topological order of the costgraph
  1430. // 2. Determine and mark the operators (and necessary edges) that are critical
  1431. std::vector<OperatorInfoPtr> topo_order;
  1432. TopologyOrder(&topo_order);
  1433. std::reverse(std::begin(topo_order), std::end(topo_order));
  1434. if (DetermineCriticalOps(topo_order) != SUCCESS) {
  1435. MS_LOG(ERROR) << "Determining critical operators failed.";
  1436. return FAILED;
  1437. }
  1438. return SUCCESS;
  1439. }
  1440. Status CostGraph::CalculateOpsMemoryCost() {
  1441. for (auto &op : ops_) {
  1442. MS_EXCEPTION_IF_NULL(op);
  1443. if (op->CalculateMemoryCost() != SUCCESS) {
  1444. MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
  1445. return FAILED;
  1446. }
  1447. }
  1448. return SUCCESS;
  1449. }
  1450. Status CostGraph::CalculateOpsMemoryCostForInference() {
  1451. for (auto &op : ops_) {
  1452. MS_EXCEPTION_IF_NULL(op);
  1453. if (op->CalculateMemoryCostForInference() != SUCCESS) {
  1454. MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
  1455. return FAILED;
  1456. }
  1457. }
  1458. return SUCCESS;
  1459. }
  1460. Status CostGraph::CalculateEdgesMemoryCost() {
  1461. for (auto &edge_pair : edges_) {
  1462. const auto &edges = edge_pair.second;
  1463. for (auto &one_edge : edges) {
  1464. if (one_edge->CalculateMemoryCost() != SUCCESS) {
  1465. MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
  1466. return FAILED;
  1467. }
  1468. }
  1469. }
  1470. return SUCCESS;
  1471. }
  1472. Status CostGraph::CalculateEdgesMemoryCostForInference() {
  1473. for (auto &edge_pair : edges_) {
  1474. const auto &edges = edge_pair.second;
  1475. for (auto &one_edge : edges) {
  1476. if (one_edge->CalculateMemoryCostForInference() != SUCCESS) {
  1477. MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
  1478. return FAILED;
  1479. }
  1480. }
  1481. }
  1482. return SUCCESS;
  1483. }
  1484. OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
  1485. for (auto one_op : ops_) {
  1486. if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
  1487. if (one_op->refkey_parameter_name() == p_name) {
  1488. return one_op;
  1489. }
  1490. }
  1491. }
  1492. return nullptr;
  1493. }
  1494. Status CostGraph::CorrectOpsMemoryCost() {
  1495. for (auto &one_op : ops_) {
  1496. if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
  1497. if (one_op->GetAliveSuccEdges().size() > 1) {
  1498. // Filter out the case when the TmpIdentity being used by multiple operators
  1499. std::map<size_t, int> output_count;
  1500. for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
  1501. auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
  1502. output_count[output_index]++;
  1503. }
  1504. for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
  1505. auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
  1506. if (output_count[output_index] <= 1) {
  1507. continue;
  1508. }
  1509. auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
  1510. MS_EXCEPTION_IF_NULL(next_op);
  1511. auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
  1512. if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
  1513. MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
  1514. << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
  1515. return FAILED;
  1516. }
  1517. output_count[output_index]--;
  1518. }
  1519. }
  1520. }
  1521. }
  1522. return SUCCESS;
  1523. }
  1524. Status CostGraph::CalculateMemoryCost() {
  1525. if (RUN_PHASE == TRAINING_PHASE) {
  1526. // training phase
  1527. if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
  1528. // Calculate operators' memory usage
  1529. if (CalculateOpsMemoryCost() != SUCCESS) {
  1530. MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
  1531. return FAILED;
  1532. }
  1533. // Calculate edges' memory usage
  1534. if (CalculateEdgesMemoryCost() != SUCCESS) {
  1535. MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
  1536. return FAILED;
  1537. }
  1538. // Correct memory usage caused by TmpIdentity
  1539. if (CorrectOpsMemoryCost() != SUCCESS) {
  1540. MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
  1541. return FAILED;
  1542. }
  1543. } else {
  1544. MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
  1545. return FAILED;
  1546. }
  1547. } else {
  1548. // inference phase
  1549. if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
  1550. // Calculate operators' memory usage
  1551. if (CalculateOpsMemoryCostForInference() != SUCCESS) {
  1552. MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
  1553. return FAILED;
  1554. }
  1555. // Calculate edges's memory usage
  1556. if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
  1557. MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
  1558. return FAILED;
  1559. }
  1560. } else {
  1561. MS_LOG(ERROR) << "Computing operators' critical flag failed.";
  1562. return FAILED;
  1563. }
  1564. }
  1565. return SUCCESS;
  1566. }
  1567. } // namespace parallel
  1568. } // namespace mindspore