You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

somas.cc 71 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. * http://www.apache.org/licenses/LICENSE-2.0
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS,
  9. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. * See the License for the specific language governing permissions and
  11. * limitations under the License.
  12. */
  13. #include "backend/optimizer/somas/somas.h"
  14. #include <algorithm>
  15. #include <cstdio>
  16. #include <fstream>
  17. #include <iterator>
  18. #include <memory>
  19. #include <numeric>
  20. #include <set>
  21. #include "backend/optimizer/somas/somas_node.h"
  22. #include "backend/optimizer/somas/somas_solver_pre.h"
  23. #include "backend/optimizer/somas/somas_stream.h"
  24. #include "backend/optimizer/somas/somas_tensor.h"
  25. #ifdef ENABLE_D
  26. #include "runtime/device/ascend/ascend_stream_assign.h"
  27. #endif
  28. #include "backend/optimizer/common/helper.h"
  29. #include "utils/ms_context.h"
  30. #include "debug/common.h"
  31. #ifdef ENABLE_DUMP_IR
  32. #include "debug/rdr/running_data_recorder.h"
  33. #endif
  34. #include "common/thread_pool.h"
  35. #include "profiler/device/ascend/memory_profiling.h"
  36. using mindspore::profiler::ascend::MemoryProfiling;
  37. using mindspore::profiler::ascend::NodeMemory;
  38. using mindspore::profiler::ascend::TensorMemory;
  39. namespace mindspore {
  40. namespace somas {
// Byte gap inserted between members of contiguous tensor lists; presumably an
// alignment requirement of the device — confirm at GenContiguousList's use sites.
constexpr auto kGapSize = 512;
// NOTE(review): despite the "Seconds" suffix, this value is passed to
// std::chrono::milliseconds in LoadSomasResult, so the retry delay is 500 ms.
constexpr auto kRetryIntervalSeconds = 500;
constexpr size_t kRefNodeTensorNum = 2;
// JSON keys used by SaveSomasResult / LoadSomasResult for the cached
// assignment file (somas_graph<id>_<hash>.json).
constexpr auto kGraphId = "graph_id";
constexpr auto kHashId = "hash_id";
constexpr auto kMemOffset = "mem_offset";
constexpr auto kNodeSize = "node_size";
constexpr auto kTensorSize = "tensor_size";
constexpr auto kContiguousSize = "contiguous_size";
constexpr auto kRefNodeSize = "ref_node_size";
constexpr auto kStreamSize = "stream_size";
constexpr auto kStreamGroupSize = "stream_group_size";
constexpr auto kTensors = "tensors";
constexpr auto kTensorId = "tensor_id";
constexpr auto kSize = "size";
constexpr auto kOriSize = "ori_size";
constexpr auto kLifelongValue = "lifelong_value";
constexpr auto kLifeStart = "life_start";
constexpr auto kLifeEnd = "life_end";
constexpr auto kOffset = "offset";
// Graphs with fewer tensors than this skip the save/load cache path entirely
// (hashing + file I/O would cost more than recomputing the assignment).
constexpr auto kCachedResultThreshold = 2000;
// Human-readable names for TensorType values, used when dumping SOMAS info.
std::map<TensorType, std::string> tensor_type_name_map = {{kCommon, "Common"},
                                                          {kOutputOnly, "OutputOnly"},
                                                          {kWorkspace, "Workspace"},
                                                          {kGetNextOutput, "GetNextOutput"},
                                                          {kSummaryInput, "SummaryInput"},
                                                          {kRefNodeInput, "RefNodeInput"},
                                                          {kRefNodeOutput, "RefNodeOutput"},
                                                          {kUnknown, "Unknown"}};
// Human-readable names for LifeLongType values, used when dumping SOMAS info.
std::map<LifeLongType, std::string> life_long_name_map = {{kLifeLongNone, "LifeLongNone"},
                                                          {kLifeLongGraphAll, "LifeLongGraphAll"},
                                                          {kLifeLongGraphStart, "LifeLongGraphStart"},
                                                          {kLifeLongGraphEnd, "LifeLongGraphEnd"}};
  74. bool Somas::Allocate(const session::KernelGraph *graph) {
  75. auto ret = InitSomasTensors(graph);
  76. if (!ret) {
  77. MS_LOG(EXCEPTION) << "Somas Initialize Failed.";
  78. }
  79. if (tensors_list_.empty()) {
  80. MS_LOG(INFO) << "No Tensor for Somas";
  81. return true;
  82. }
  83. ret = LoadSomasCache(graph);
  84. if (ret) {
  85. GenGraphStatisticInfo();
  86. return ret;
  87. }
  88. // Computing Conflict pairs
  89. MS_LOG(INFO) << "Start Computing Conflict Pairs";
  90. ComputeConflictPairs();
  91. MS_LOG(INFO) << "End Computing Conflict Pairs";
  92. ret = Assign(graph);
  93. if (!ret) {
  94. MS_LOG(EXCEPTION) << "Somas Assign Failed.";
  95. }
  96. SaveSomasResult(graph);
  97. GenGraphStatisticInfo();
  98. return ret;
  99. }
  100. bool Somas::LoadSomasCache(const session::KernelGraph *graph) {
  101. MS_EXCEPTION_IF_NULL(graph);
  102. if (tensors_list_.size() < kCachedResultThreshold) {
  103. MS_LOG(DEBUG) << "Tensors size (" << tensors_list_.size() << ") less than " << kCachedResultThreshold
  104. << ", no need to load cached";
  105. return false;
  106. }
  107. bool ret = CalcSomasModelHash(graph);
  108. if (ret) {
  109. std::string filename = GetSaveGraphsPathName(
  110. "/somas_meta/somas_graph" + std::to_string(graph->graph_id()) + "_" + hash_id_ + ".json", save_graphs_path_);
  111. ret = LoadSomasResult(graph, filename);
  112. if (ret) {
  113. MS_LOG(INFO) << "Load Somas Cache file " << filename << " Successfully.";
  114. }
  115. } else {
  116. MS_LOG(ERROR) << "Calculate somas's model hash id failed.";
  117. }
  118. return ret;
  119. }
  120. bool Somas::CalcSomasModelHash(const session::KernelGraph *graph) {
  121. MS_EXCEPTION_IF_NULL(graph);
  122. auto model_str = SomasInfo(true);
  123. hash_id_ = std::to_string(std::hash<std::string>()(model_str));
  124. MS_LOG(INFO) << "Graph " << graph->graph_id() << "'s SOMAS Model hash id is " << hash_id_;
  125. std::string filename = GetSaveGraphsPathName(
  126. "/somas_meta/somas_graph" + std::to_string(graph->graph_id()) + "_" + hash_id_ + ".info", save_graphs_path_);
  127. return Common::SaveStringToFile(filename, model_str);
  128. }
  129. bool Somas::SaveSomasResult(const session::KernelGraph *graph) {
  130. MS_EXCEPTION_IF_NULL(graph);
  131. if (tensors_list_.size() < kCachedResultThreshold) {
  132. MS_LOG(DEBUG) << "Tensors size (" << tensors_list_.size() << ") less than " << kCachedResultThreshold
  133. << ", no need to save result";
  134. return false;
  135. }
  136. nlohmann::json somas_json;
  137. somas_json[kGraphId] = graph->graph_id();
  138. somas_json[kHashId] = hash_id_;
  139. somas_json[kMemOffset] = mem_offset_;
  140. somas_json[kNodeSize] = nodes_list_.size();
  141. somas_json[kTensorSize] = tensors_list_.size();
  142. somas_json[kContiguousSize] = contiguous_tensors_list_.size();
  143. somas_json[kRefNodeSize] = ref_node_constraints_.size();
  144. somas_json[kStreamSize] = streams_list_.size();
  145. somas_json[kStreamGroupSize] = streams_groups_.size();
  146. std::vector<nlohmann::json> tensors_json;
  147. for (auto &tensor : tensors_list_) {
  148. MS_EXCEPTION_IF_NULL(tensor);
  149. nlohmann::json tensor_json;
  150. tensor_json[kTensorId] = tensor->GetId();
  151. tensor_json[kSize] = tensor->GetAlignedSize();
  152. tensor_json[kOriSize] = tensor->GetOriginalSize();
  153. tensor_json[kLifelongValue] = tensor->lifelong_value_;
  154. tensor_json[kLifeStart] = tensor->lifetime_.start_;
  155. tensor_json[kLifeEnd] = tensor->lifetime_.end_;
  156. tensor_json[kOffset] = tensor->GetOffset();
  157. tensors_json.emplace_back(tensor_json);
  158. }
  159. somas_json[kTensors] = tensors_json;
  160. std::string filename = GetSaveGraphsPathName(
  161. "/somas_meta/somas_graph" + std::to_string(graph->graph_id()) + "_" + hash_id_ + ".json", save_graphs_path_);
  162. (void)Common::SaveStringToFile(filename, somas_json.dump());
  163. return true;
  164. }
  165. bool Somas::LoadSomasResult(const session::KernelGraph *graph, const string &filename) {
  166. if (filename.length() <= strlen(".json")) {
  167. MS_LOG(WARNING) << "please check somas cache file path.";
  168. return false;
  169. }
  170. std::ifstream somas_json_fs(filename);
  171. if (!somas_json_fs.is_open()) {
  172. MS_LOG(INFO) << "Open json file: " << filename << " error, Somas Cache Missed.";
  173. return false;
  174. }
  175. nlohmann::json somas_json;
  176. try {
  177. somas_json_fs >> somas_json;
  178. somas_json_fs.close();
  179. } catch (std::exception &e) {
  180. MS_LOG(WARNING) << "Parse json file error: " << filename << ", sleep 500ms and retry again.";
  181. somas_json_fs.close();
  182. std::this_thread::sleep_for(std::chrono::milliseconds(kRetryIntervalSeconds));
  183. std::ifstream retry_tmp(filename);
  184. if (!retry_tmp.is_open()) {
  185. MS_LOG(INFO) << "Open json file: " << filename << " error, please check kernel_meta.";
  186. return false;
  187. }
  188. retry_tmp >> somas_json;
  189. retry_tmp.close();
  190. }
  191. auto ret = VerifySomasResult(graph, somas_json);
  192. if (!ret) {
  193. MS_LOG(WARNING) << "Verify Somas Result Failed.";
  194. return false;
  195. }
  196. auto mem_offset = somas_json[kMemOffset];
  197. mem_offset_ = mem_offset;
  198. ret = UpdateTensorsOffset(somas_json[kTensors]);
  199. return ret;
  200. }
// Checks that a cached JSON result still describes the current graph model:
// same graph id, same model hash, and identical counts of nodes, tensors,
// contiguous lists, ref-node constraints, streams and stream groups. Any
// mismatch invalidates the cache (returns false with a WARNING naming the
// offending field).
bool Somas::VerifySomasResult(const session::KernelGraph *graph, const nlohmann::json &somas_json) const {
  MS_EXCEPTION_IF_NULL(graph);
  auto graph_id = somas_json[kGraphId];
  auto hash_id = somas_json[kHashId];
  auto node_size = somas_json[kNodeSize];
  auto tensor_size = somas_json[kTensorSize];
  auto contiguous_size = somas_json[kContiguousSize];
  auto ref_node_size = somas_json[kRefNodeSize];
  auto stream_size = somas_json[kStreamSize];
  auto stream_group_size = somas_json[kStreamGroupSize];
  if (graph_id != graph->graph_id()) {
    MS_LOG(WARNING) << "Mismatch graph id " << graph_id << " vs " << graph->graph_id();
    return false;
  }
  if (hash_id != hash_id_) {
    MS_LOG(WARNING) << "Mismatch hash id " << hash_id << " vs " << hash_id_;
    return false;
  }
  if (node_size != nodes_list_.size()) {
    MS_LOG(WARNING) << "Mismatch node size " << node_size << " vs " << nodes_list_.size();
    return false;
  }
  if (tensor_size != tensors_list_.size()) {
    MS_LOG(WARNING) << "Mismatch tensor size " << tensor_size << " vs " << tensors_list_.size();
    return false;
  }
  if (contiguous_size != contiguous_tensors_list_.size()) {
    MS_LOG(WARNING) << "Mismatch contiguous size " << contiguous_size << " vs " << contiguous_tensors_list_.size();
    return false;
  }
  if (ref_node_size != ref_node_constraints_.size()) {
    MS_LOG(WARNING) << "Mismatch ref node size " << ref_node_size << " vs " << ref_node_constraints_.size();
    return false;
  }
  if (stream_size != streams_list_.size()) {
    MS_LOG(WARNING) << "Mismatch stream size " << stream_size << " vs " << streams_list_.size();
    return false;
  }
  if (stream_group_size != streams_groups_.size()) {
    MS_LOG(WARNING) << "Mismatch stream group size " << stream_group_size << " vs " << streams_groups_.size();
    return false;
  }
  return true;
}
  245. bool Somas::UpdateTensorsOffset(const std::vector<nlohmann::json> &tensors_json) {
  246. bool ret = true;
  247. for (auto &tensor_json : tensors_json) {
  248. auto tensor_id = tensor_json[kTensorId];
  249. auto size = tensor_json[kSize];
  250. auto ori_size = tensor_json[kOriSize];
  251. auto lifelong_value = tensor_json[kLifelongValue];
  252. auto life_start = tensor_json[kLifeStart];
  253. auto life_end = tensor_json[kLifeEnd];
  254. auto offset = tensor_json[kOffset];
  255. auto iter = tensors_map_.find(tensor_id);
  256. if (iter != tensors_map_.end()) {
  257. MS_EXCEPTION_IF_NULL(iter->second);
  258. if (size != iter->second->aligned_size_) {
  259. MS_LOG(WARNING) << "Mismatch size of tensor " << tensor_id << " " << size << " vs "
  260. << iter->second->aligned_size_;
  261. ret = false;
  262. break;
  263. }
  264. if (ori_size != iter->second->GetOriginalSize()) {
  265. MS_LOG(WARNING) << "Mismatch original size of tensor " << tensor_id << " " << ori_size << " vs "
  266. << iter->second->GetOriginalSize();
  267. ret = false;
  268. break;
  269. }
  270. if (lifelong_value != iter->second->lifelong_value_) {
  271. MS_LOG(WARNING) << "Mismatch lifelong value of tensor " << tensor_id << " " << lifelong_value << " vs "
  272. << iter->second->lifelong_value_;
  273. ret = false;
  274. break;
  275. }
  276. if (life_start != iter->second->lifetime_.start_) {
  277. MS_LOG(WARNING) << "Mismatch life start of tensor " << tensor_id << " " << life_start << " vs "
  278. << iter->second->lifetime_.start_;
  279. ret = false;
  280. break;
  281. }
  282. if (life_end != iter->second->lifetime_.end_) {
  283. MS_LOG(WARNING) << "Mismatch life start of tensor " << tensor_id << " " << life_end << " vs "
  284. << iter->second->lifetime_.end_;
  285. ret = false;
  286. break;
  287. }
  288. // verify pass, update memory offset
  289. iter->second->offset_ = offset;
  290. } else {
  291. MS_LOG(WARNING) << "Can't find tensor " << tensor_id;
  292. ret = false;
  293. break;
  294. }
  295. }
  296. return ret;
  297. }
// Builds the full SOMAS model for `graph`: basic streams/nodes/tensors, then
// the category-specific passes that adjust lifetimes and reuse rules, and
// finally the contiguous-list generation. Also dumps the pre-processed model
// to RDR and (when save_graphs_ is set) to .ir files. Always returns true.
bool Somas::InitSomasTensors(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  InitBasicInfo(graph);
  // Per-category post-processing passes; each adjusts tensors created above.
  IndependentNodeOutputProcess(graph);
  SummaryInputProcess(graph);
  RefNodeProcess(graph);
  NonTaskSplitProcess(graph);
  UnReuseNodeProcess(graph);
  GenContiguousList(graph);
  GetNextOutputProcess(graph);
  if (tensors_list_.empty()) {
    MS_LOG(INFO) << "No Tensor from graph " << graph->graph_id();
    return true;
  }
  MS_LOG(INFO) << "Created " << streams_list_.size() << " streams (" << streams_groups_.size() << " groups), "
               << nodes_list_.size() << " nodes, " << tensors_list_.size() << " tensors, and "
               << contiguous_tensors_list_.size() << " contiguous lists";
#ifdef ENABLE_DUMP_IR
  // Record the pre-processed model and the offline solver log in RDR for
  // post-mortem debugging.
  SubModuleId module = SubModuleId::SM_OPTIMIZER;
  std::string name = "somas_pre_processed_info." + std::to_string(graph->graph_id());
  (void)mindspore::RDR::RecordString(module, name, SomasInfo());
  name = "somas_offline_log." + std::to_string(graph->graph_id());
  (void)mindspore::RDR::RecordString(module, name, Offline());
#endif
  if (save_graphs_) {
    std::string file_path = GetSaveGraphsPathName(
      "/somas_pre_processed_info_" + std::to_string(graph->graph_id()) + ".ir", save_graphs_path_);
    DumpSomasInfoIR(file_path);
    std::string offline_file_path =
      GetSaveGraphsPathName("/somas_offline_log_" + std::to_string(graph->graph_id()) + ".ir", save_graphs_path_);
    DumpOfflineIR(offline_file_path);
  }
  return true;
}
  332. void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {
  333. MS_EXCEPTION_IF_NULL(graph);
  334. std::vector<CNodePtr> kernel_cnodes;
  335. streams_list_ = {};
  336. nodes_list_ = {};
  337. size_t node_index = 0;
  338. if (graph->subgraph_multi_call()) {
  339. kernel_cnodes = graph->mem_reuse_exec_order();
  340. } else {
  341. kernel_cnodes = graph->execution_order();
  342. }
  343. for (size_t i = 0; i < kernel_cnodes.size(); i++) {
  344. auto kernel = kernel_cnodes[i];
  345. MS_EXCEPTION_IF_NULL(kernel);
  346. SomasStreamPtr stream;
  347. auto stream_id = AnfAlgo::GetStreamId(kernel);
  348. auto it = find_if(streams_list_.begin(), streams_list_.end(),
  349. [stream_id](const SomasStreamPtr &s) { return s->GetId() == stream_id; });
  350. if (it == streams_list_.end()) {
  351. stream = std::make_shared<SomasStream>(stream_id);
  352. streams_list_.push_back(stream);
  353. } else {
  354. stream = *it;
  355. }
  356. // Node
  357. NodeType type = kCommonNode;
  358. if (AnfAlgo::IsCommunicationOp(kernel)) {
  359. type = kCommunicationNode;
  360. }
  361. auto node = std::make_shared<SomasNode>(node_index, type, stream);
  362. MS_EXCEPTION_IF_NULL(node);
  363. node->scope_full_name_ = kernel->fullname_with_scope();
  364. nodes_list_.push_back(node);
  365. stream->nodes_.push_back(node);
  366. auto key = kernel.get();
  367. auto &nodes = nodes_map_[key];
  368. nodes.push_back(node);
  369. node_index++;
  370. }
  371. }
// Creates one SomasTensor per kernel output and per kernel workspace slot,
// registering each tensor with its producing node(s), its stream, the global
// tensors_list_ and tensors_map_. Tensors whose address already exists (e.g.
// pre-assigned by another pass) get aligned_size_ = 0 so the solver skips them.
void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  tensors_list_ = {};
  size_t tensor_index = 0;
  auto kernel_cnodes = graph->execution_order();
  for (const auto &kernel : kernel_cnodes) {
    // nodes_map_ was filled by InitSomasStreamAndNode; nodes.size() > 1 means
    // the kernel appears multiple times in the reuse order.
    auto nodes = nodes_map_[kernel.get()];
    auto node = nodes[0];
    MS_EXCEPTION_IF_NULL(node);
    auto stream = node->GetStream();
    MS_EXCEPTION_IF_NULL(stream);
    // Output tensors: one per entry of the kernel's output size list.
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto output_sizes = kernel_mod->GetOutputSizeList();
    auto index = 0;
    for (const auto &size : output_sizes) {
      auto output_tensor_index = tensor_index;
      tensor_index++;
      // NOTE(review): an earlier comment here claimed outputs are made
      // lifelong, but the code passes kLifeLongNone; the lifetime instead
      // spans from this node to its last repeated occurrence.
      auto tensor = std::make_shared<SomasTensor>(output_tensor_index, node, stream, size, kLifeLongNone);
      MS_EXCEPTION_IF_NULL(tensor);
      tensor->lifetime_.start_ = node->GetId();
      tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId();
      tensor->type_ = kOutputOnly;
      // An already-assigned output address means SOMAS must not re-plan it.
      if (AnfAlgo::OutputAddrExist(kernel, index)) {
        tensor->aligned_size_ = 0;
      }
      tensors_list_.push_back(tensor);
      tensors_map_[output_tensor_index] = tensor;
      stream->tensors_.push_back(tensor);
      std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) {
        MS_EXCEPTION_IF_NULL(node);
        node->tensors_.insert(tensor);
        node->output_tensors_.push_back(tensor);
      });
      index++;
    }
    // Workspace tensors: same registration scheme as outputs.
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    index = 0;
    for (const auto &size : workspace_sizes) {
      auto workspace_tensor_index = tensor_index;
      tensor_index++;
      SomasTensorPtr tensor = std::make_shared<SomasTensor>(workspace_tensor_index, node, stream, size, kLifeLongNone);
      MS_EXCEPTION_IF_NULL(tensor);
      tensor->type_ = kWorkspace;
      tensor->lifetime_.start_ = node->GetId();
      tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId();
      if (AnfAlgo::WorkspaceAddrExist(kernel, index)) {
        tensor->aligned_size_ = 0;
      }
      tensors_list_.push_back(tensor);
      tensors_map_[workspace_tensor_index] = tensor;
      stream->tensors_.push_back(tensor);
      std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) {
        MS_EXCEPTION_IF_NULL(node);
        node->tensors_.insert(tensor);
        node->workspace_tensors_.push_back(tensor);
      });
      index++;
    }
  }
}
  436. void Somas::InitSomasInputTensors(const session::KernelGraph *graph) {
  437. MS_EXCEPTION_IF_NULL(graph);
  438. bool is_all_nop_node = opt::IsAllNopNode(graph);
  439. static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1");
  440. auto kernel_cnodes = graph->execution_order();
  441. for (const auto &kernel : kernel_cnodes) {
  442. if (AnfAlgo::GetCNodeName(kernel) != kAtomicAddrCleanOpName) {
  443. InitCommonNodeInputs(is_all_nop_node, kernel);
  444. } else {
  445. InitAtomicCleanInputs(enable_fusion_clear, kernel);
  446. }
  447. }
  448. }
  449. void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
  450. auto nodes = nodes_map_[kernel.get()];
  451. auto node = nodes[0];
  452. MS_EXCEPTION_IF_NULL(node);
  453. auto stream = node->GetStream();
  454. MS_EXCEPTION_IF_NULL(stream);
  455. // Input Tensor
  456. auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
  457. size_t real_input_index = 0;
  458. for (size_t i = 0; i < input_tensor_num; i++) {
  459. auto input_node = kernel->input(i + 1);
  460. MS_EXCEPTION_IF_NULL(input_node);
  461. session::KernelWithIndex prenode_index;
  462. if (is_all_nop_node) {
  463. prenode_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
  464. } else {
  465. prenode_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
  466. }
  467. if (AnfAlgo::CheckPrimitiveType(prenode_index.first, prim::kPrimMakeTuple)) {
  468. MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
  469. }
  470. MS_EXCEPTION_IF_NULL(prenode_index.first);
  471. if (!AnfAlgo::IsRealCNodeKernel(prenode_index.first)) {
  472. auto op_name = AnfAlgo::GetCNodeName(kernel);
  473. TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel, i);
  474. if ((op_name == kDynamicRNNOpName || op_name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) {
  475. continue;
  476. }
  477. auto parameter = GetSomasParameter(prenode_index.first, prenode_index.second);
  478. node->input_parameters_map_[real_input_index] = parameter;
  479. real_input_index++;
  480. MS_LOG(DEBUG) << "Input [" << prenode_index.first->fullname_with_scope() << "] is not a real cnode kernel.";
  481. continue;
  482. }
  483. auto iter = nodes_map_.find(prenode_index.first.get());
  484. if (iter == nodes_map_.end()) {
  485. MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input " << i << " ["
  486. << prenode_index.first->fullname_with_scope() << "] is not init.";
  487. }
  488. auto pre_somas_node = iter->second.at(0);
  489. if (prenode_index.second > pre_somas_node->output_tensors_.size()) {
  490. MS_LOG(EXCEPTION) << "Output index " << prenode_index.second << " exceed input node ["
  491. << prenode_index.first->fullname_with_scope() << "]'s outputs size "
  492. << pre_somas_node->output_tensors_.size();
  493. }
  494. auto input_somas_tensor = pre_somas_node->output_tensors_[prenode_index.second];
  495. MS_EXCEPTION_IF_NULL(input_somas_tensor);
  496. std::for_each(nodes.begin(), nodes.end(),
  497. [input_somas_tensor](auto &node) { node->input_tensors_.push_back(input_somas_tensor); });
  498. real_input_index++;
  499. if (input_somas_tensor->type_ == kOutputOnly) {
  500. input_somas_tensor->type_ = kCommon;
  501. }
  502. input_somas_tensor->destinationStreams_.insert(stream);
  503. for (auto &repeat_node : nodes) {
  504. input_somas_tensor->destinations_.insert(repeat_node);
  505. if (input_somas_tensor->lifetime_.end_ < repeat_node->GetId()) {
  506. input_somas_tensor->lifetime_.end_ = repeat_node->GetId();
  507. }
  508. }
  509. if (node != pre_somas_node) {
  510. node->ancestor_nodes_.insert(pre_somas_node);
  511. }
  512. auto input_tensor_stream = input_somas_tensor->GetSourceStream();
  513. if (input_tensor_stream != stream) {
  514. stream->ancestor_streams_.insert(input_tensor_stream);
  515. input_somas_tensor->between_streams_ = true;
  516. }
  517. }
  518. }
  519. void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kernel) {
  520. auto node = nodes_map_[kernel.get()].at(0);
  521. MS_EXCEPTION_IF_NULL(node);
  522. auto stream = node->GetStream();
  523. MS_EXCEPTION_IF_NULL(stream);
  524. auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
  525. for (size_t i = 0; i < input_tensor_num; i++) {
  526. MS_EXCEPTION_IF_NULL(kernel->inputs()[i + 1]);
  527. auto pre_node = kernel->input(i + 1)->cast<CNodePtr>();
  528. auto iter = nodes_map_.find(pre_node.get());
  529. if (iter == nodes_map_.end()) {
  530. MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input ["
  531. << pre_node->fullname_with_scope() << "] is not init.";
  532. }
  533. auto pre_somas_node = iter->second.at(0);
  534. MS_EXCEPTION_IF_NULL(pre_somas_node);
  535. // set clean output tensors
  536. if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
  537. auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
  538. for (auto index : clean_output_indexs) {
  539. if (index > pre_somas_node->output_tensors_.size()) {
  540. MS_LOG(EXCEPTION) << "Output index " << index << " exceed input node [" << pre_node->fullname_with_scope()
  541. << "]'s outputs size " << pre_somas_node->output_tensors_.size();
  542. }
  543. auto input_somas_tensor = pre_somas_node->output_tensors_[index];
  544. MS_EXCEPTION_IF_NULL(input_somas_tensor);
  545. node->input_tensors_.push_back(input_somas_tensor);
  546. if (enable_fusion_clear) {
  547. input_somas_tensor->lifelong_value_ = kLifeLongGraphAll;
  548. MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_
  549. << " 's output" << index << " to lifelong";
  550. }
  551. }
  552. }
  553. // set clean workspace tensors
  554. if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
  555. auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
  556. for (const auto &index : clean_workspace_indexs) {
  557. if (index > pre_somas_node->output_tensors_.size()) {
  558. MS_LOG(EXCEPTION) << "Workspace index " << index << " exceed input node [" << pre_node->fullname_with_scope()
  559. << "]'s Workspace size " << pre_somas_node->workspace_tensors_.size();
  560. }
  561. auto input_somas_tensor = pre_somas_node->workspace_tensors_[index];
  562. MS_EXCEPTION_IF_NULL(input_somas_tensor);
  563. node->input_tensors_.push_back(input_somas_tensor);
  564. if (enable_fusion_clear) {
  565. input_somas_tensor->lifelong_value_ = kLifeLongGraphAll;
  566. MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_
  567. << " 's workspace" << index << " to lifelong";
  568. }
  569. }
  570. }
  571. }
  572. }
  573. SomasParameterPtr Somas::CreateSomasParameter(AnfNodePtr node, size_t index) {
  574. MS_EXCEPTION_IF_NULL(node);
  575. auto id = parameters_list_.size();
  576. auto device_addr = AnfAlgo::GetOutputAddr(node, index);
  577. if (device_addr == nullptr) {
  578. MS_LOG(EXCEPTION) << "Node " << node->fullname_with_scope() << " has no device address before Somas.";
  579. }
  580. auto param = std::make_shared<SomasParameter>(id, node->fullname_with_scope(), index, device_addr->GetPtr(),
  581. device_addr->GetSize());
  582. parameters_list_.push_back(param);
  583. return param;
  584. }
  585. SomasParameterPtr Somas::GetSomasParameter(AnfNodePtr node, size_t index) {
  586. auto key = node.get();
  587. auto iter = parameters_map_.find(key);
  588. if (iter != parameters_map_.end()) {
  589. auto it = std::find_if(iter->second.begin(), iter->second.end(),
  590. [index](SomasParameterPtr param) -> bool { return index == param->output_index_; });
  591. if (it != iter->second.end()) {
  592. return *it;
  593. } else {
  594. auto new_param = CreateSomasParameter(node, index);
  595. iter->second.push_back(new_param);
  596. return new_param;
  597. }
  598. } else {
  599. auto param = CreateSomasParameter(node, index);
  600. parameters_map_[key].push_back(param);
  601. return param;
  602. }
  603. }
// Build all SOMAS bookkeeping for <graph>: streams, nodes, output/workspace
// tensors and input tensors. Also records the initial SOMAS state for RDR and
// (optionally) dumps it as IR when save_graphs is enabled.
void Somas::InitBasicInfo(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
#ifdef ENABLE_D
  // Ascend build only: stream groups come from the stream-assignment pass.
  streams_groups_ = device::ascend::AscendStreamAssign::GetInstance().get_stream_group();
#endif
  InitSomasStreamAndNode(graph);
  InitSomasOutputAndWorkspaceTensors(graph);
  InitSomasInputTensors(graph);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
#ifdef ENABLE_DUMP_IR
  // Record the pre-assignment SOMAS info so it can be retrieved via RDR later.
  SubModuleId module = SubModuleId::SM_OPTIMIZER;
  std::string name = "somas_initial_info." + std::to_string(graph->graph_id());
  (void)mindspore::RDR::RecordString(module, name, SomasInfo());
#endif
  save_graphs_ = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  save_graphs_path_ = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
  if (save_graphs_path_.empty()) {
    // Fall back to the current directory when no dump path is configured.
    save_graphs_path_ = ".";
  }
  if (save_graphs_) {
    std::string file_path =
      GetSaveGraphsPathName("/somas_initial_info_" + std::to_string(graph->graph_id()) + ".ir", save_graphs_path_);
    DumpSomasInfoIR(file_path);
  }
}
  630. void Somas::GetNextOutputProcess(const session::KernelGraph *graph) {
  631. MS_EXCEPTION_IF_NULL(graph);
  632. auto kernel_cnodes = graph->execution_order();
  633. size_t total_size = 0;
  634. for (const auto &kernel : kernel_cnodes) {
  635. if (AnfAlgo::GetCNodeName(kernel) != kGetNextOpName) {
  636. continue;
  637. }
  638. auto iter = nodes_map_.find(kernel.get());
  639. if (iter != nodes_map_.end()) {
  640. auto &node = iter->second.at(0);
  641. MS_EXCEPTION_IF_NULL(node);
  642. auto getnext_output_tensors = node->output_tensors_;
  643. for (auto &tensor : getnext_output_tensors) {
  644. MS_EXCEPTION_IF_NULL(tensor);
  645. total_size += tensor->GetAlignedSize();
  646. tensor->lifelong_value_ = kLifeLongGraphAll;
  647. tensor->type_ = kGetNextOutput;
  648. }
  649. }
  650. }
  651. MS_LOG(INFO) << "Special Tensor total size: GetNext Output " << total_size;
  652. }
  653. void Somas::IndependentNodeOutputProcess(const session::KernelGraph *graph) {
  654. MS_EXCEPTION_IF_NULL(graph);
  655. auto kernel_cnodes = graph->execution_order();
  656. size_t total_size = 0;
  657. for (const auto &kernel : kernel_cnodes) {
  658. bool independent = AnfAlgo::IsIndependentNode(kernel);
  659. if (!independent) {
  660. continue;
  661. }
  662. auto iter = nodes_map_.find(kernel.get());
  663. if (iter != nodes_map_.end()) {
  664. auto &node = iter->second.at(0);
  665. MS_EXCEPTION_IF_NULL(node);
  666. auto semi_reuse_output_tensors = node->output_tensors_;
  667. for (auto &tensor : semi_reuse_output_tensors) {
  668. MS_EXCEPTION_IF_NULL(tensor);
  669. total_size += tensor->GetAlignedSize();
  670. tensor->lifelong_value_ = kLifeLongGraphAll;
  671. }
  672. }
  673. }
  674. MS_LOG(INFO) << "Special Tensor total size: Independent Node output " << total_size;
  675. }
  676. void Somas::SummaryInputProcess(const session::KernelGraph *graph) {
  677. MS_EXCEPTION_IF_NULL(graph);
  678. bool summary_exist = graph->summary_node_exist();
  679. if (!summary_exist) {
  680. return;
  681. }
  682. auto summary_nodes = graph->summary_nodes();
  683. if (summary_nodes.empty()) {
  684. return;
  685. }
  686. size_t total_summary_size = 0;
  687. for (auto &node_item : summary_nodes) {
  688. auto node = node_item.second.first;
  689. size_t index = IntToSize(node_item.second.second);
  690. auto iter = nodes_map_.find(node.get());
  691. if (iter != nodes_map_.end()) {
  692. auto input_node = iter->second.at(0);
  693. MS_EXCEPTION_IF_NULL(input_node);
  694. if (index < input_node->output_tensors_.size()) {
  695. auto tensor = input_node->output_tensors_[index];
  696. MS_EXCEPTION_IF_NULL(tensor);
  697. tensor->lifelong_value_ = kLifeLongGraphAll;
  698. tensor->type_ = kSummaryInput;
  699. total_summary_size += tensor->GetAlignedSize();
  700. MS_LOG(INFO) << "Set summary node input tensor's lifelong, node: " << node->fullname_with_scope()
  701. << " index: " << index;
  702. } else {
  703. MS_LOG(WARNING) << "Index exceed size, node " << node->fullname_with_scope() << " index: " << index
  704. << " size: " << input_node->output_tensors_.size();
  705. }
  706. } else {
  707. MS_LOG(WARNING) << "Can't find summary input node " << node->fullname_with_scope() << " index: " << index;
  708. }
  709. }
  710. MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
  711. }
// Record ref-node constraints: when a kernel output is registered in the
// graph's ref-output map, it corresponds to another node's output, and both
// SOMAS tensors must end up at the same offset (see UpdateRefTensorsOffset()).
// Tags tensors kRefNodeInput / kRefNodeOutput and appends each
// {input_id, output_id} pair to ref_node_constraints_.
void Somas::RefNodeProcess(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto kernel_cnodes = graph->execution_order();
  size_t total_output_size = 0;
  size_t total_input_size = 0;
  for (const auto &kernel : kernel_cnodes) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    if (kernel_mod == nullptr) {
      MS_LOG(WARNING) << "Kernel mode is NULL Of " << kernel->fullname_with_scope();
      continue;
    }
    auto output_sizes = kernel_mod->GetOutputSizeList();
    size_t output_index = 0;
    for (const auto &size : output_sizes) {
      auto out_index = output_index;
      output_index++;
      session::AnfWithOutIndex out_pair(kernel, out_index);
      if (graph->IsInRefOutputMap(out_pair)) {
        // This output refers to output <origin_pair> of another node.
        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto &node = nodes_map_[kernel.get()].at(0);
        MS_EXCEPTION_IF_NULL(node);
        auto output_tensor = node->output_tensors_[out_index];
        MS_EXCEPTION_IF_NULL(output_tensor);
        output_tensor->type_ = kRefNodeOutput;
        total_output_size += size;
        // Only real CNode kernels have a SOMAS node to pair with; other origins
        // (e.g. parameters) are counted above but produce no constraint.
        if (AnfAlgo::IsRealCNodeKernel(origin_pair.first)) {
          auto ori_node = origin_pair.first->cast<CNodePtr>();
          auto ori_index = origin_pair.second;
          if (nodes_map_.find(ori_node.get()) == nodes_map_.end()) {
            MS_LOG(EXCEPTION)
              << "The ori_node is not included in nodes_map_ constructed from exec_order of graph. Info ori_node: "
              << ori_node->DebugString();
          }
          auto &repeat_node = nodes_map_[ori_node.get()].at(0);
          MS_EXCEPTION_IF_NULL(repeat_node);
          auto input_tensor = repeat_node->output_tensors_[ori_index];
          MS_EXCEPTION_IF_NULL(input_tensor);
          input_tensor->type_ = kRefNodeInput;
          total_input_size += input_tensor->aligned_size_;
          std::vector<size_t> refnode_input_output;
          refnode_input_output.push_back(input_tensor->GetId());
          refnode_input_output.push_back(output_tensor->GetId());
          ref_node_constraints_.push_back(refnode_input_output);
          MS_LOG(INFO) << "RefNode: input " << input_tensor->GetId() << " output " << output_tensor->GetId();
        }
      }
    }
  }
  MS_LOG(INFO) << "Special Tensor total size: RefNode: input " << total_input_size << " output " << total_output_size;
}
  763. void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) {
  764. MS_EXCEPTION_IF_NULL(graph);
  765. auto kernel_cnodes = graph->execution_order();
  766. for (const auto &kernel : kernel_cnodes) {
  767. auto op_name = AnfAlgo::GetCNodeName(kernel);
  768. if ((op_name == kSplitOpName || op_name == kSplitVOpName) && AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) {
  769. std::vector<size_t> refnode_input_output;
  770. auto node = nodes_map_[kernel.get()].at(0);
  771. MS_EXCEPTION_IF_NULL(node);
  772. if (node->input_tensors_.size() == 0) {
  773. MS_LOG(EXCEPTION) << op_name << " has no input tensor, can not do split non_task process.";
  774. }
  775. auto input_tensor = node->input_tensors_[0];
  776. MS_EXCEPTION_IF_NULL(input_tensor);
  777. input_tensor->type_ = kRefNodeInput;
  778. refnode_input_output.push_back(input_tensor->GetId());
  779. for (auto &output_tensor : node->output_tensors_) {
  780. MS_EXCEPTION_IF_NULL(output_tensor);
  781. output_tensor->type_ = kRefNodeOutput;
  782. refnode_input_output.push_back(output_tensor->GetId());
  783. }
  784. ref_node_constraints_.push_back(refnode_input_output);
  785. }
  786. }
  787. }
// Force all input/output/workspace tensors of kernels whose full name appears
// in full_name_list to be lifelong, i.e. excluded from memory reuse.
// NOTE(review): full_name_list is constructed empty right here, so the early
// return below makes this whole pass a no-op. It looks like a manual debugging
// hook (fill the list to pin specific kernels) — confirm before relying on it.
void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  vector<string> full_name_list = {};
  if (full_name_list.size() == 0) {
    return;
  }
  auto kernel_cnodes = graph->execution_order();
  for (const auto &kernel : kernel_cnodes) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto full_name = kernel->fullname_with_scope();
    auto iter = std::find(full_name_list.begin(), full_name_list.end(), full_name);
    if (iter != full_name_list.end()) {
      MS_LOG(INFO) << "Set UnReuse Node in somas, Node:" << full_name;
      auto key = kernel.get();
      auto somas_node = nodes_map_[key].at(0);
      MS_EXCEPTION_IF_NULL(somas_node);
      // input
      auto inputs = somas_node->input_tensors_;
      for (auto &input : inputs) {
        MS_EXCEPTION_IF_NULL(input);
        input->lifelong_value_ = kLifeLongGraphAll;
      }
      // output
      auto outputs = somas_node->output_tensors_;
      MS_LOG(INFO) << "Output size of " << kernel->fullname_with_scope() << " is " << outputs.size();
      for (auto &output : outputs) {
        MS_EXCEPTION_IF_NULL(output);
        output->lifelong_value_ = kLifeLongGraphAll;
      }
      // workspace
      auto workspaces = somas_node->workspace_tensors_;
      for (auto &workspace : workspaces) {
        MS_EXCEPTION_IF_NULL(workspace);
        workspace->lifelong_value_ = kLifeLongGraphAll;
      }
    }
  }
}
  826. void Somas::GenContiguousList(const session::KernelGraph *graph) {
  827. MS_EXCEPTION_IF_NULL(graph);
  828. for (const auto &node : nodes_list_) {
  829. MS_EXCEPTION_IF_NULL(node);
  830. if (node->GetType() != kCommunicationNode) {
  831. continue;
  832. }
  833. // Contiguous input
  834. if ((!node->input_tensors_.empty()) && (!node->input_tensors_[0]->contiguous_)) {
  835. if (node->input_tensors_[0]->aligned_size_) {
  836. node->input_tensors_[0]->aligned_size_ += kGapSize;
  837. }
  838. if (node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_) {
  839. node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_ += kGapSize;
  840. }
  841. std::vector<size_t> inputs;
  842. for (const auto &input_tensor : node->input_tensors_) {
  843. MS_EXCEPTION_IF_NULL(input_tensor);
  844. comm_input_total_size_ += input_tensor->aligned_size_;
  845. input_tensor->contiguous_ = true;
  846. inputs.push_back(input_tensor->GetId());
  847. }
  848. contiguous_tensors_list_.push_back(inputs);
  849. }
  850. // Contiguous output
  851. if ((!node->output_tensors_.empty()) && (!node->output_tensors_[0]->contiguous_)) {
  852. if (node->output_tensors_[0]->aligned_size_) {
  853. node->output_tensors_[0]->aligned_size_ += kGapSize;
  854. }
  855. if (node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_) {
  856. node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_ += kGapSize;
  857. }
  858. std::vector<size_t> outputs;
  859. for (const auto &output_tensor : node->output_tensors_) {
  860. MS_EXCEPTION_IF_NULL(output_tensor);
  861. comm_output_total_size_ += output_tensor->aligned_size_;
  862. output_tensor->contiguous_ = true;
  863. outputs.push_back(output_tensor->GetId());
  864. }
  865. contiguous_tensors_list_.push_back(outputs);
  866. }
  867. }
  868. }
  869. void Somas::ComputeConflictPairs() {
  870. if (tensors_list_.empty()) {
  871. MS_LOG(INFO) << "No Tensor for Conflict computing";
  872. return;
  873. }
  874. MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)";
  875. auto start_conflict = std::chrono::system_clock::now();
  876. std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort);
  877. UpdateTensorDestinations();
  878. MS_LOG(INFO) << "Start Bitset";
  879. std::vector<DynamicBitSet> nodes_dependency;
  880. size_t count = nodes_list_.back()->GetId() + 1;
  881. for (size_t i = 0; i < count; i++) {
  882. nodes_dependency.emplace_back(count);
  883. }
  884. MS_LOG(INFO) << "Start Path Computing";
  885. // Loop to compute ancestor paths via bitset for time dependence
  886. for (const auto &node : nodes_list_) {
  887. for (const auto &ancestor : node->ancestor_nodes_) {
  888. nodes_dependency[node->GetId()].SetBitTrue(ancestor->GetId());
  889. Union(&nodes_dependency[node->GetId()], &nodes_dependency[ancestor->GetId()]);
  890. }
  891. }
  892. MS_LOG(INFO) << "End Path Computing";
  893. MS_LOG(INFO) << "Start Tensor Relation Computing";
  894. count = tensors_list_.back()->GetId() + 1;
  895. for (size_t i = 0; i < count; i++) {
  896. reuse_matrix_.emplace_back(count);
  897. }
  898. if (tensors_list_.size() < kParallelComputeSizeThreshold) {
  899. ComputeMultiTensorConflicts(tensors_list_, tensors_list_, nodes_dependency, &reuse_matrix_);
  900. } else {
  901. MS_LOG(INFO) << "Tensor Num " << tensors_list_.size() << " is larger than " << kParallelComputeSizeThreshold;
  902. MS_LOG(INFO) << "Enter Multi-Thread Mode...";
  903. size_t process_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
  904. MS_LOG(INFO) << "Threads Num is " << process_num;
  905. size_t start_index = 0;
  906. size_t total_size = tensors_list_.size();
  907. size_t job_size = total_size / process_num;
  908. if (job_size == 0) {
  909. job_size = total_size;
  910. }
  911. std::vector<common::Task> tasks;
  912. while (start_index < total_size) {
  913. size_t end_index = (start_index + job_size) > total_size ? total_size : start_index + job_size;
  914. auto jobs = std::vector<SomasTensorPtr>(tensors_list_.begin() + start_index, tensors_list_.begin() + end_index);
  915. auto task = [this, jobs, &nodes_dependency]() {
  916. this->ComputeMultiTensorConflicts(jobs, tensors_list_, nodes_dependency, &reuse_matrix_);
  917. return common::SUCCESS;
  918. };
  919. tasks.emplace_back(task);
  920. start_index += job_size;
  921. }
  922. common::ThreadPool::GetInstance().SyncRun(tasks);
  923. }
  924. MS_LOG(INFO) << "End Tensor Relation Computing";
  925. auto end_conflict = std::chrono::system_clock::now();
  926. MS_LOG(INFO) << "End Conflict Computing (Bitset Model)(time taken "
  927. << std::chrono::duration_cast<std::chrono::milliseconds>(end_conflict - start_conflict).count() << "ms)";
  928. }
  929. void Somas::UpdateTensorDestinations() {
  930. // Loop to add edges within each stream (node order within stream)
  931. for (const auto &stream : streams_list_) {
  932. MS_EXCEPTION_IF_NULL(stream);
  933. auto &nodes = stream->nodes_;
  934. std::sort(nodes.begin(), nodes.end(), NodeSort);
  935. for (size_t i = 1; i < nodes.size(); i++) {
  936. const auto &previous_node = nodes[i - 1];
  937. const auto &current_node = nodes[i];
  938. MS_EXCEPTION_IF_NULL(current_node);
  939. current_node->ancestor_nodes_.insert(previous_node);
  940. }
  941. }
  942. // Loop to add edges from end to beginning of next group
  943. for (const auto &group : streams_groups_) {
  944. for (size_t i = 1; i < group.size(); i++) {
  945. int64_t previous_stream = group[i - 1];
  946. int64_t current_stream = group[i];
  947. auto it =
  948. std::find_if(streams_list_.begin(), streams_list_.end(),
  949. [previous_stream](const SomasStreamPtr &stream) { return stream->GetId() == previous_stream; });
  950. if (it == streams_list_.end()) {
  951. continue;
  952. }
  953. auto &last_node_in_prev_stream = (*it)->nodes_.back();
  954. it = std::find_if(streams_list_.begin(), streams_list_.end(),
  955. [current_stream](const SomasStreamPtr &stream) { return stream->GetId() == current_stream; });
  956. if (it == streams_list_.end()) {
  957. continue;
  958. }
  959. auto &first_node_in_cur_stream = (*it)->nodes_.front();
  960. first_node_in_cur_stream->ancestor_nodes_.insert(last_node_in_prev_stream);
  961. }
  962. }
  963. // Loop to avoid tensors with empty destinations (add itself)
  964. for (const auto &tensor : tensors_list_) {
  965. MS_EXCEPTION_IF_NULL(tensor);
  966. if (tensor->destinations_.size() == 0) {
  967. tensor->destinations_.insert(tensor->GetSourceNode());
  968. }
  969. }
  970. // Loop to compute max destinations in each stream
  971. for (const auto &tensor : tensors_list_) {
  972. MS_EXCEPTION_IF_NULL(tensor);
  973. tensor->ComputeMaxDestinationId();
  974. }
  975. }
  976. void Somas::ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> &calc_tensors_list,
  977. const std::vector<SomasTensorPtr> &all_tensors_list,
  978. const vector<DynamicBitSet> &nodes_dependency,
  979. std::vector<DynamicBitSet> *tensor_relation) const {
  980. auto start = std::chrono::system_clock::now();
  981. MS_LOG(INFO) << "Start Computing Conflicts Pairs, tensors list size is " << calc_tensors_list.size();
  982. for (size_t i = 0; i < calc_tensors_list.size(); i++) {
  983. auto calc_tensor = calc_tensors_list[i];
  984. MS_EXCEPTION_IF_NULL(calc_tensor);
  985. if (calc_tensor->IsLifelong() || calc_tensor->IsSemiLifelongEnd() || calc_tensor->IsRefOverlap() ||
  986. calc_tensor->GetAlignedSize() == 0) {
  987. continue;
  988. }
  989. ComputeOneTensorConflicts(calc_tensor, all_tensors_list, nodes_dependency, tensor_relation);
  990. }
  991. auto end = std::chrono::system_clock::now();
  992. MS_LOG(INFO) << "End Computing Conflicts Pairs (time taken "
  993. << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms)";
  994. }
  995. void Somas::ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> &calc_tensor,
  996. const std::vector<SomasTensorPtr> &all_tensors_list,
  997. const vector<DynamicBitSet> &nodes_dependency,
  998. std::vector<DynamicBitSet> *tensor_relation) const {
  999. MS_EXCEPTION_IF_NULL(calc_tensor);
  1000. for (size_t j = 0; j < all_tensors_list.size(); j++) {
  1001. auto target_tensor = all_tensors_list[j];
  1002. MS_EXCEPTION_IF_NULL(target_tensor);
  1003. if (calc_tensor == target_tensor || target_tensor->IsLifelong() || target_tensor->IsSemiLifelongStart() ||
  1004. target_tensor->IsRefOverlap() || target_tensor->GetAlignedSize() == 0) {
  1005. continue;
  1006. }
  1007. size_t calc_src_node = calc_tensor->GetSourceNode()->GetId();
  1008. size_t target_src_node = target_tensor->GetSourceNode()->GetId();
  1009. if (calc_src_node == target_src_node) {
  1010. continue;
  1011. }
  1012. if ((*tensor_relation)[calc_tensor->GetId()].IsBitTrue(target_tensor->GetId()) ||
  1013. (*tensor_relation)[target_tensor->GetId()].IsBitTrue(calc_tensor->GetId())) {
  1014. continue;
  1015. }
  1016. bool reuse = true;
  1017. // check calc_tensor's all consumers is target_tensor's source node's dependency or not
  1018. for (const auto &dst_map : calc_tensor->max_destinations_) {
  1019. const auto &dst_node = dst_map.second;
  1020. MS_EXCEPTION_IF_NULL(dst_node);
  1021. if (nodes_dependency[target_src_node].IsBitTrue(dst_node->GetId()) == false) {
  1022. // calc_tensor's consumer is not in target_tensor's source node's dependency, not sure this consumer is done or
  1023. // not when target_tensor produced
  1024. reuse = false;
  1025. break;
  1026. } else if (target_src_node == dst_node->GetId()) {
  1027. // calc_tensor is target_tensor's source node's input, can't reuse
  1028. reuse = false;
  1029. break;
  1030. } else {
  1031. // calc_tensor's consumer is in target_tensor's source node's dependency, this consumer is done when
  1032. // target_tensor produced
  1033. reuse = true;
  1034. }
  1035. }
  1036. if (reuse) {
  1037. // calc_tensor and target_tensor have dependencies so they can reuse each other
  1038. (*tensor_relation)[calc_tensor->GetId()].SetBitTrue(target_tensor->GetId());
  1039. (*tensor_relation)[target_tensor->GetId()].SetBitTrue(calc_tensor->GetId());
  1040. }
  1041. }
  1042. }
  1043. bool Somas::NodeSort(SomasNodePtr node1, SomasNodePtr node2) { return node1->GetId() < node2->GetId(); }
  1044. bool Somas::Assign(const session::KernelGraph *graph) {
  1045. if (tensors_list_.empty()) {
  1046. MS_LOG(INFO) << "No Tensor for Assigner";
  1047. return true;
  1048. }
  1049. // Ref Node Preprocessing
  1050. UpdateRefTensorsConflict();
  1051. std::map<size_t, size_t> contiguous_list_with_ref_index_map = GetContiguousListContainRefTensor();
  1052. vector<vector<size_t>> contiguous_tensors_list_removed = contiguous_tensors_list_;
  1053. std::set<vector<size_t>> contiguous_tensors_list_to_remove;
  1054. for (auto ref_list_pair : contiguous_list_with_ref_index_map) {
  1055. contiguous_tensors_list_to_remove.insert(contiguous_tensors_list_[ref_list_pair.second]);
  1056. }
  1057. // remove the contiguous list which all tensors' align size is 0
  1058. for (auto contiguous_list : contiguous_tensors_list_) {
  1059. bool all_outputs = true;
  1060. for (auto tensor_id : contiguous_list) {
  1061. auto tensor = tensors_list_[tensor_id];
  1062. MS_EXCEPTION_IF_NULL(tensor);
  1063. if (tensor->aligned_size_ != 0) {
  1064. all_outputs = false;
  1065. break;
  1066. }
  1067. }
  1068. if (all_outputs) {
  1069. contiguous_tensors_list_to_remove.insert(contiguous_list);
  1070. }
  1071. }
  1072. for (auto contiguous_list : contiguous_tensors_list_to_remove) {
  1073. auto iterator =
  1074. std::find(contiguous_tensors_list_removed.begin(), contiguous_tensors_list_removed.end(), contiguous_list);
  1075. if (iterator != contiguous_tensors_list_removed.end()) {
  1076. contiguous_tensors_list_removed.erase(iterator);
  1077. } else {
  1078. MS_LOG(WARNING) << "Could not find contiguous list to remove for ref";
  1079. }
  1080. }
  1081. MS_LOG(INFO) << "End Solving Preprocessing for Ref Node";
  1082. UpdateRefOverlapTensorsConflicts();
  1083. #ifdef SOMAS_DEBUG
  1084. // Compute number of constraints for each tensor
  1085. auto tensors_num = tensors_list_.size();
  1086. for (auto tensor1 : tensors_list_) {
  1087. auto ones_num = reuse_matrix_[tensor1->GetId()].CountOnesNum();
  1088. tensor1->num_constraints_ = tensors_num - ones_num;
  1089. }
  1090. #endif
  1091. // Prepare solver info
  1092. MS_LOG(INFO) << "Start Loop to create solver info";
  1093. for (auto tensor : tensors_list_) {
  1094. MS_EXCEPTION_IF_NULL(tensor);
  1095. if (tensor->GetSolverTensorDesc() != nullptr) {
  1096. SomasSolverTensorDescPtr pSolverTensor = tensor->GetSolverTensorDesc();
  1097. solver_tensor_desc_map_.insert(std::pair<size_t, SomasSolverTensorDescPtr>(pSolverTensor->index_, pSolverTensor));
  1098. }
  1099. }
  1100. MS_LOG(INFO) << "End Loop to create solver info";
  1101. MS_LOG(INFO) << "Start Solving";
  1102. if (solver_tensor_desc_map_.empty()) {
  1103. MS_LOG(INFO) << "solver_tensor_desc_list is empty.";
  1104. return true;
  1105. }
  1106. somas_solver_ = std::make_shared<SomasSolverPre>();
  1107. auto status =
  1108. somas_solver_->Solving(graph, &solver_tensor_desc_map_, &reuse_matrix_, contiguous_tensors_list_removed, false);
  1109. MS_LOG(INFO) << "End Solving";
  1110. if (status != SUCCESS) {
  1111. GenGraphStatisticInfo();
  1112. MS_LOG(EXCEPTION) << "SOMAS Solving Failed.";
  1113. }
  1114. // Update solver_tensor_desc offset to tensors list
  1115. for (const auto &tensor : tensors_list_) {
  1116. MS_EXCEPTION_IF_NULL(tensor);
  1117. tensor->SetOffset();
  1118. }
  1119. UpdateRefTensorsOffset();
  1120. UpdateContiguousTensorsOffset(contiguous_list_with_ref_index_map);
  1121. // Set mem_offset_ value by solver result
  1122. mem_offset_ = static_cast<size_t>(somas_solver_->GetMaxOffset());
  1123. return true;
  1124. }
// Match contiguous lists that are tied together through ref-node tensor pairs.
// Returns a map from the index of the contiguous list holding the ref inputs
// to the index of the list holding the corresponding ref outputs. Also runs
// consistency checks (same in-list position, same list size, every position
// covered by a ref pair) and logs warnings on violations.
std::map<size_t, size_t> Somas::GetContiguousListContainRefTensor() {
  // key: contiguous list index with ref node input; value: contiguous list index with ref node output
  std::map<size_t, size_t> contiguous_list_with_ref_index_map;
  std::map<size_t, size_t> ref_tensors_in_contiguous_map = GetRefTensorsInContiguousList();
  std::map<size_t, std::map<size_t, std::set<size_t>>> contiguous_ref_list_error_check_map;
  for (auto ref_pair : ref_tensors_in_contiguous_map) {
    size_t ref_first = ref_pair.first;
    size_t ref_second = ref_pair.second;
    bool found_first = false;
    bool found_second = false;
    size_t index_first = 0;
    size_t index_second = 0;
    size_t index_in_list_first = 0;
    size_t index_in_list_second = 0;
    // Locate the contiguous list (and in-list position) of each tensor of the pair.
    for (size_t index = 0; index < contiguous_tensors_list_.size() && (!found_first || !found_second); index++) {
      if (!found_first) {
        auto iterator_first =
          std::find(contiguous_tensors_list_[index].begin(), contiguous_tensors_list_[index].end(), ref_first);
        if (iterator_first != contiguous_tensors_list_[index].end()) {
          index_first = index;
          index_in_list_first = iterator_first - contiguous_tensors_list_[index].begin();
          found_first = true;
        }
      }
      if (!found_second) {
        auto iterator_second =
          std::find(contiguous_tensors_list_[index].begin(), contiguous_tensors_list_[index].end(), ref_second);
        if (iterator_second != contiguous_tensors_list_[index].end()) {
          index_second = index;
          index_in_list_second = iterator_second - contiguous_tensors_list_[index].begin();
          found_second = true;
        }
      }
    }
    if (!found_first) {
      MS_LOG(WARNING) << "Contiguous ref tensor " << ref_first << " not found in any contiguous list";
    }
    if (!found_second) {
      MS_LOG(WARNING) << "Contiguous ref tensor " << ref_second << " not found in any contiguous list";
    }
    // NOTE(review): when a tensor was not found above, index_first/index_second
    // keep their default 0 yet the pair is still recorded below — confirm that
    // falling through (rather than skipping the pair) is intended.
    if (contiguous_list_with_ref_index_map.find(index_first) == contiguous_list_with_ref_index_map.end() ||
        contiguous_list_with_ref_index_map[index_first] == index_second) {
      contiguous_list_with_ref_index_map[index_first] = index_second;
      // Checking for error cases
      if (index_in_list_first != index_in_list_second) {
        MS_LOG(WARNING) << "Inconsistency in contiguous ref: tensor " << ref_first << " in position "
                        << index_in_list_first << " of contiguous list " << index_first << " and tensor " << ref_second
                        << " in position " << index_in_list_second << " of contiguous list " << index_second;
      }
      contiguous_ref_list_error_check_map[index_first][index_second].insert(index_in_list_first);
    } else {
      MS_LOG(WARNING) << "Contiguous list " << index_first << " associated (ref node) with two other contiguous lists: "
                      << contiguous_list_with_ref_index_map[index_first] << " and " << index_second;
    }
  }
  // Verify matched lists are the same size and every in-list position was
  // covered by some ref pair; mismatches are logged but not fatal.
  for (auto check_list_pair : contiguous_ref_list_error_check_map) {
    auto first_list = check_list_pair.first;
    auto index_set_map = check_list_pair.second;
    for (auto index_set : index_set_map) {
      auto second_list = index_set.first;
      if (contiguous_tensors_list_[first_list].size() != contiguous_tensors_list_[second_list].size()) {
        MS_LOG(WARNING) << "Contiguous lists " << first_list << " and " << second_list
                        << " considered in ref do not have the same size";
      }
      for (size_t x = 0; x < contiguous_tensors_list_[second_list].size(); x++) {
        if (contiguous_ref_list_error_check_map[first_list][second_list].count(x) == 0) {
          MS_LOG(WARNING) << "Contiguous lists " << first_list << " and " << second_list
                          << " considered in ref: ref pair at in-lists index " << x << " has not been considered";
        }
      }
    }
  }
  return contiguous_list_with_ref_index_map;
}
  1199. std::map<size_t, size_t> Somas::GetRefTensorsInContiguousList() {
  1200. // key: refnode input value: refnode output
  1201. std::map<size_t, size_t> ref_tensors_in_contiguous_map;
  1202. for (auto ref_node_list : ref_node_constraints_) {
  1203. // Count contiguous tensors in ref list
  1204. size_t contiguous_in_ref_list = std::count_if(ref_node_list.begin(), ref_node_list.end(),
  1205. [this](size_t tid) { return tensors_map_[tid]->contiguous_; });
  1206. // Keep info about contiguous and check for errors
  1207. if (ref_node_list.size() > kRefNodeTensorNum && contiguous_in_ref_list > 0) {
  1208. MS_LOG(WARNING) << "Ref node of size greater than two with at least one contiguous tensor in";
  1209. }
  1210. if (ref_node_list.size() == kRefNodeTensorNum && contiguous_in_ref_list == 1) {
  1211. MS_LOG(WARNING) << "Ref node of size two with only one contiguous tensor" << ref_node_list[0] << ":"
  1212. << tensors_map_[ref_node_list[0]]->contiguous_ << ", " << ref_node_list[1] << ":"
  1213. << tensors_map_[ref_node_list[1]]->contiguous_;
  1214. }
  1215. if (ref_node_list.size() == kRefNodeTensorNum && contiguous_in_ref_list == kRefNodeTensorNum) {
  1216. ref_tensors_in_contiguous_map[ref_node_list[0]] = ref_node_list[1];
  1217. }
  1218. }
  1219. return ref_tensors_in_contiguous_map;
  1220. }
  1221. void Somas::UpdateContiguousTensorsOffset(const std::map<size_t, size_t> &contiguous_ref_list_map) {
  1222. // Handle contiguous ref node
  1223. for (auto ref_list_pair : contiguous_ref_list_map) {
  1224. size_t index_first = ref_list_pair.first;
  1225. size_t index_second = ref_list_pair.second;
  1226. for (size_t x = 0; x < contiguous_tensors_list_[index_second].size(); x++) {
  1227. tensors_map_[contiguous_tensors_list_[index_second][x]]->offset_ =
  1228. tensors_map_[contiguous_tensors_list_[index_first][x]]->offset_;
  1229. }
  1230. }
  1231. // Contiguous gaps postprocessing
  1232. for (auto list : contiguous_tensors_list_) {
  1233. tensors_map_[list[0]]->offset_ += kGapSize;
  1234. }
  1235. }
  1236. void Somas::UpdateRefTensorsOffset() {
  1237. // Ref Node Postprocessing
  1238. MS_LOG(INFO) << "\nStart Solving Postprocessing for Ref Node";
  1239. // Set offset for rest of ref node list (ignored by solver due to ref node preprocessing)
  1240. for (auto ref_node_list : ref_node_constraints_) {
  1241. for (size_t i = 1; i < ref_node_list.size(); ++i) {
  1242. tensors_map_[ref_node_list[i]]->offset_ = tensors_map_[ref_node_list[0]]->offset_;
  1243. }
  1244. }
  1245. }
  1246. void Somas::UpdateRefOverlapTensorsConflicts() {
  1247. // Ref Overlap Preprocessing
  1248. MS_LOG(INFO) << "Start Solving Preprocessing for Ref Overlap";
  1249. // In ConflictComputing(), by use of ref_overlap_ flag, each tensor in a ref_overlap_list has all entries 1 in
  1250. // cannot_reuse_ array Here, we allow reuse only among tensors in same list
  1251. for (auto ref_overlap_list : ref_overlap_constraints_) {
  1252. for (size_t tid_1 : ref_overlap_list) {
  1253. for (size_t tid_2 : ref_overlap_list) {
  1254. reuse_matrix_[tid_1].SetBitTrue(tid_2);
  1255. reuse_matrix_[tid_2].SetBitTrue(tid_1);
  1256. }
  1257. }
  1258. }
  1259. MS_LOG(INFO) << "End Solving Preprocessing for Ref Overlap";
  1260. }
  1261. void Somas::UpdateRefTensorsConflict() {
  1262. // Keep all constraints for first tensor in list
  1263. for (auto ref_node_list : ref_node_constraints_) {
  1264. size_t tid_0 = ref_node_list[0];
  1265. for (SomasTensorPtr tensor : tensors_list_) {
  1266. if (reuse_matrix_[tid_0].IsBitTrue(tensor->GetId()) == false) {
  1267. continue;
  1268. }
  1269. for (size_t tid : ref_node_list) {
  1270. if (reuse_matrix_[tid].IsBitTrue(tensor->GetId()) == false) {
  1271. reuse_matrix_[tid_0].SetBitFalse(tensor->GetId());
  1272. reuse_matrix_[tensor->GetId()].SetBitFalse(tid_0);
  1273. break;
  1274. }
  1275. }
  1276. }
  1277. // Set rest to size 0, so that solver ignores them (if not contiguous)
  1278. for (size_t i = 1; i < ref_node_list.size(); ++i) {
  1279. if (!tensors_map_[ref_node_list[i]]->contiguous_) {
  1280. tensors_map_[ref_node_list[i]]->aligned_size_ = 0;
  1281. }
  1282. }
  1283. }
  1284. }
  1285. std::string Somas::GetSplitName(const std::string &scope_name) const {
  1286. auto index = scope_name.rfind('/');
  1287. if (index == std::string::npos) {
  1288. return scope_name;
  1289. } else {
  1290. if (index < scope_name.size() - 1) {
  1291. auto split_name = scope_name.substr(index + 1);
  1292. return split_name;
  1293. }
  1294. return scope_name;
  1295. }
  1296. }
  1297. std::string Somas::SomasInfo(bool calc_hash) const {
  1298. std::ostringstream oss;
  1299. if (!calc_hash) {
  1300. DumpParameters(oss);
  1301. }
  1302. DumpTensors(oss);
  1303. DumpNodes(oss);
  1304. oss << "\n\nAll Stream Groups:\n\n";
  1305. for (const auto &stream_group : streams_groups_) {
  1306. for (const auto &stream : stream_group) {
  1307. oss << "stm" << stream << " ";
  1308. }
  1309. oss << "\n";
  1310. }
  1311. if (!ref_node_constraints_.empty()) {
  1312. oss << "\n\nAll Ref Node Info:\n\n";
  1313. for (const auto &ref_in_out : ref_node_constraints_) {
  1314. oss << "refnode input-output:";
  1315. for (const auto &item : ref_in_out) {
  1316. oss << "%" << item << "T ";
  1317. }
  1318. oss << "\n";
  1319. }
  1320. }
  1321. return oss.str();
  1322. }
void Somas::DumpNodes(std::ostringstream &oss) const {
  // Append one line per SOMAS node: id, short name, numeric type, then the
  // node's inputs (%nP parameters / %nT tensors, interleaved by input
  // position), output tensors, workspace tensors, and stream id.
  oss << "\n\nAll Nodes:\n\n";
  for (const auto &node : nodes_list_) {
    MS_EXCEPTION_IF_NULL(node);
    auto scope_name = node->scope_full_name_;
    std::string split_name = GetSplitName(scope_name);
    oss << "$" << node->GetId() << "\t" << split_name << "\t" << static_cast<int>(node->GetType()) << "\t";
    // Inputs come from two sources: input_parameters_map_ (keyed by input
    // position) and input_tensors_ (consumed in order for the positions not
    // present in the map).
    auto input_num = node->input_tensors_.size() + node->input_parameters_map_.size();
    oss << "inputs[";
    size_t tensor_index = 0;  // next unconsumed entry of input_tensors_
    for (size_t input_index = 0; input_index < input_num; input_index++) {
      auto iter = node->input_parameters_map_.find(input_index);
      if (iter != node->input_parameters_map_.end()) {
        // This position is a graph parameter.
        oss << "%" << iter->second->id_ << "P"
            << ", ";
      } else {
        // Otherwise it is the next input tensor in declaration order.
        // NOTE(review): assumes the map keys and tensor count exactly cover
        // [0, input_num) — otherwise tensor_index could run past the vector.
        oss << "%" << node->input_tensors_[tensor_index]->GetId() << "T"
            << ", ";
        tensor_index++;
      }
    }
    oss << "]";
    oss << "\toutputs[";
    for (const auto &out : node->output_tensors_) {
      MS_EXCEPTION_IF_NULL(out);
      oss << "%" << out->GetId() << "T"
          << ", ";
    }
    oss << "]";
    oss << "\tworkspace[";
    for (const auto &wk : node->workspace_tensors_) {
      MS_EXCEPTION_IF_NULL(wk);
      oss << "%" << wk->GetId() << "T"
          << ", ";
    }
    oss << "]";
    oss << "\tstreamID["
        << "@" << node->GetStream()->GetId() << "]\n";
  }
}
void Somas::DumpTensors(std::ostringstream &oss) const {
  // Append a table of every tensor: id, aligned and original sizes, assigned
  // offset, absolute address (mem_base_addr_ + offset), type name, lifelong
  // flag, lifetime interval, and the short name of its producing node.
  oss << "\n\nAll Tensors:\n\n";
  oss << "index:"
      << "\tsize:"
      << "\treal_size:"
      << "\toffset:"
      << "\taddr:"
      << "\ttype:"
      << "\tlifelong:"
      << "\tlife_start:"
      << "\tlife_end:"
      << "\tsource node name:\n";
  for (const auto &tensor : tensors_list_) {
    MS_EXCEPTION_IF_NULL(tensor);
    // NOTE(review): dereferences GetSourceNode() without a null check —
    // presumably every tensor in tensors_list_ has a source node; verify
    // (SomasMemory() below does guard against a null source node).
    auto scope_name = tensor->GetSourceNode()->scope_full_name_;
    std::string split_name = GetSplitName(scope_name);
    oss << "%" << tensor->GetId() << "T"
        << "\t"
        << "#" << tensor->GetAlignedSize() << "S"
        << "\t"
        << "#" << tensor->GetOriginalSize() << "S"
        << "\t"
        << "&" << tensor->GetOffset() << ""
        << "\t"
        // Cast to void* so the stream prints the pointer value in hex.
        << "&" << static_cast<void *>(tensor->GetOffset() + mem_base_addr_) << "\t"
        << tensor_type_name_map[tensor->type_] << "\t" << tensor->IsLifelong() << "\t" << tensor->lifetime_.start_
        << "\t" << tensor->lifetime_.end_ << "\t" << split_name << "\n";
  }
}
  1392. void Somas::DumpParameters(std::ostringstream &oss) const {
  1393. oss << "All Parameters:\n\n";
  1394. oss << "index:"
  1395. << "\tsize:"
  1396. << "\tstart_addr:"
  1397. << "\tsource node name:"
  1398. << "\tnode out index:\n";
  1399. for (const auto &param : parameters_list_) {
  1400. MS_EXCEPTION_IF_NULL(param);
  1401. oss << "%" << param->id_ << "P"
  1402. << "\t"
  1403. << "#" << param->size_ << "S"
  1404. << "\t"
  1405. << "&" << param->addr_ << "\t" << param->source_node_name_ << "\t" << param->output_index_ << "\n";
  1406. }
  1407. }
  1408. void Somas::DumpSomasInfoIR(const string filename) const { (void)Common::SaveStringToFile(filename, SomasInfo()); }
std::string Somas::Offline() const {
  // Serialize the graph in the offline-solver log format: one "EDGE" line per
  // (tensor, destination-node) pair, then "CONTIGUOUS" lines for contiguous
  // tensor lists and "GROUP" lines for stream groups.
  std::ostringstream oss;
  for (auto tensor : tensors_list_) {
    MS_EXCEPTION_IF_NULL(tensor);
    if (tensor->IsOutputOnly() || tensor->type_ == TensorType::kRefNodeOutput) {
      // Tensor with no real consumer: emit an explicit error edge so the log
      // stays complete and parseable.
      oss << "Somas EDGE ERROR src=n" << tensor->GetSourceNode()->GetId()
          << ", srcstm=" << tensor->GetSourceStream()->GetId() << ", dst=nc"
          << ", dststm=nc"
          << ", workspace=0, size=" << tensor->GetOriginalSize()
          << ", lifelong=" << static_cast<int>(tensor->lifelong_value_) << ", tid=" << tensor->GetId()
          << ", start=" << tensor->lifetime_.start_ << ", end=" << tensor->lifetime_.end_ << std::endl;
    } else {
      // De-duplicate destinations by node id; insert() keeps the first stream
      // id seen per node (NOTE(review): presumably a node has a single stream,
      // so this never discards information — verify).
      std::map<size_t, size_t> dest_infos;
      for (SomasNodePtr dest_node : tensor->destinations_) {
        dest_infos.insert(std::make_pair(dest_node->GetId(), dest_node->GetStream()->GetId()));
      }
      for (auto dest_info : dest_infos) {
        oss << "Somas EDGE src=n" << tensor->GetSourceNode()->GetId()
            << ", srcstm=" << tensor->GetSourceStream()->GetId() << ", dst=n" << dest_info.first
            << ", dststm=" << dest_info.second << ", workspace=" << static_cast<int>(tensor->type_ == kWorkspace)
            << ", size=" << tensor->GetOriginalSize() << ", lifelong=" << static_cast<int>(tensor->lifelong_value_)
            << ", tid=" << tensor->GetId() << ", start=" << tensor->lifetime_.start_
            << ", end=" << tensor->lifetime_.end_ << std::endl;
      }
    }
  }
  // One CONTIGUOUS line per contiguous tensor-id list.
  for (vector<size_t> tList : contiguous_tensors_list_) {
    oss << "Somas CONTIGUOUS";
    for (size_t tid : tList) {
      oss << " " << tid;
    }
    oss << std::endl;
  }
  // One GROUP line per stream group.
  for (const auto &group : streams_groups_) {
    oss << "Somas GROUP";
    for (int64_t sid : group) {
      oss << " " << sid;
    }
    oss << std::endl;
  }
  return oss.str();
}
  1451. void Somas::DumpOfflineIR(const string filename) const {
  1452. MS_LOG(INFO) << "Printing somas-log-from-graph log: " << filename;
  1453. (void)Common::SaveStringToFile(filename, Offline());
  1454. }
std::string Somas::SomasMemory() const {
  // Render the final memory layout as a table ordered by offset: tensors are
  // grouped by their assigned offset, with a dense "mem_id" numbering the
  // distinct offsets in increasing order.
  std::ostringstream oss;
  // First pass: collect the distinct offsets ...
  std::map<size_t, size_t> mem_map;
  for (auto tensor : tensors_list_) {
    MS_EXCEPTION_IF_NULL(tensor);
    mem_map[tensor->GetOffset()] = 0;
  }
  // ... then number them 0..N-1 in ascending offset order (map iteration order).
  size_t num = 0;
  for (auto iter = mem_map.begin(); iter != mem_map.end(); ++iter, ++num) {
    iter->second = num;
  }
  // Second pass: bucket tensors by offset; within a bucket, order by tensor id.
  std::map<size_t, std::map<size_t, SomasTensorPtr>> mem_list;
  for (const auto &output_tensor : tensors_list_) {
    MS_EXCEPTION_IF_NULL(output_tensor);
    size_t key = output_tensor->offset_;
    auto iter = mem_list.find(key);
    if (iter == mem_list.end()) {
      std::map<size_t, SomasTensorPtr> id_tensor_map;
      id_tensor_map[output_tensor->GetId()] = output_tensor;
      mem_list[key] = id_tensor_map;
    } else {
      iter->second[output_tensor->GetId()] = output_tensor;
    }
  }
  oss << "mem_id:"
      << "\tstart_offset:"
      << "\tend_offset:"
      << "\ttensor_id:"
      << "\torigin_size:"
      << "\talign_size:"
      << "\tstart_addr:"
      << "\tend_addr:"
      << "\ttype:"
      << "\tsrc_node:"
      << "\tsrc_stm_id:"
      << "lifetime_start\t"
      << "lifetime_end\n";
  for (const auto &mem : mem_list) {
    auto id_tensor_map = mem.second;
    for (const auto &id_tensor : id_tensor_map) {
      auto place_tensor = id_tensor.second;
      MS_EXCEPTION_IF_NULL(place_tensor);
      std::string scope_name;
      size_t src_stm_id = 0xffff;  // sentinel stream id for tensors without a source node
      if (place_tensor->GetSourceNode() != nullptr) {
        scope_name = place_tensor->GetSourceNode()->scope_full_name_;
        src_stm_id = place_tensor->GetSourceNode()->GetStream()->GetId();
      } else {
        scope_name = "Somas Tensor";
      }
      std::string split_name = GetSplitName(scope_name);
      // Absolute start/end addresses are mem_base_addr_ + offset (+ size),
      // cast to void* so the stream prints them in hex.
      oss << "#" << mem_map[place_tensor->GetOffset()] << "\t" << place_tensor->GetOffset() << "\t"
          << place_tensor->GetOffset() + place_tensor->GetAlignedSize() << "\t%" << place_tensor->GetId() << "T\t"
          << place_tensor->GetOriginalSize() << "\t" << place_tensor->GetAlignedSize() << "\t&"
          << static_cast<void *>(place_tensor->GetOffset() + mem_base_addr_) << "\t&"
          << static_cast<void *>(place_tensor->GetOffset() + mem_base_addr_ + place_tensor->GetAlignedSize()) << "\t"
          << tensor_type_name_map[place_tensor->type_] << "\t" << split_name << "\tstm" << src_stm_id << "\t"
          << place_tensor->lifetime_.start_ << "\t" << place_tensor->lifetime_.end_ << "\n";
    }
  }
  return oss.str();
}
  1517. void Somas::DumpSomasMemoryIR(const string filename) const { (void)Common::SaveStringToFile(filename, SomasMemory()); }
  1518. size_t Somas::CalcLowerBound() const {
  1519. size_t max_node_id = std::accumulate(tensors_list_.begin(), tensors_list_.end(), 0, [](size_t max_id, auto tensor) {
  1520. return std::max(max_id, tensor->lifetime_.end_);
  1521. });
  1522. std::map<size_t, size_t> lifetime_lb;
  1523. for (size_t time = 0; time <= max_node_id; time++) {
  1524. lifetime_lb[time] = 0;
  1525. }
  1526. size_t lower, upper;
  1527. for (auto tensor : tensors_list_) {
  1528. MS_EXCEPTION_IF_NULL(tensor);
  1529. if (tensor->lifelong_value_ == kLifeLongGraphAll) {
  1530. lower = 0;
  1531. upper = max_node_id;
  1532. } else {
  1533. lower = tensor->lifetime_.start_;
  1534. upper = tensor->lifetime_.end_;
  1535. }
  1536. for (size_t time = lower; time <= upper; time++) {
  1537. lifetime_lb[time] += tensor->GetAlignedSize();
  1538. }
  1539. }
  1540. size_t max_lifetime = 0;
  1541. for (size_t time = 0; time <= max_node_id; time++) {
  1542. if (max_lifetime < lifetime_lb[time]) {
  1543. max_lifetime = lifetime_lb[time];
  1544. }
  1545. }
  1546. return max_lifetime;
  1547. }
  1548. void Somas::GenGraphStatisticInfo() {
  1549. lower_bound_ = CalcLowerBound();
  1550. for (const auto &tensor : tensors_list_) {
  1551. MS_EXCEPTION_IF_NULL(tensor);
  1552. upper_bound_ += tensor->aligned_size_;
  1553. if (tensor->type_ == kWorkspace) {
  1554. workspace_total_size_ += tensor->aligned_size_;
  1555. }
  1556. if (tensor->lifelong_value_ == kLifeLongGraphAll) {
  1557. lifelong_all_total_size_ += tensor->aligned_size_;
  1558. } else if (tensor->lifelong_value_ == kLifeLongGraphStart) {
  1559. lifelong_start_total_size_ += tensor->aligned_size_;
  1560. } else if (tensor->lifelong_value_ == kLifeLongGraphEnd) {
  1561. lifelong_end_total_size_ += tensor->aligned_size_;
  1562. }
  1563. }
  1564. const double giga = 1024. * 1024. * 1024.;
  1565. MS_LOG(INFO) << "Lower Bound: " << lower_bound_ << " (" << lower_bound_ / giga
  1566. << " GB), Upper Bound: " << upper_bound_ << " (" << upper_bound_ / giga << " GB)";
  1567. MS_LOG(INFO) << "\nTotal Dynamic Size (Upper Bound):\t" << upper_bound_ << "\n"
  1568. << "Theoretical Optimal Size (Lower Bound):\t" << lower_bound_ << "\n"
  1569. << "Total Workspace Size:\t" << workspace_total_size_ << "\n"
  1570. << "Total Communication Input Tensor Size:\t" << comm_input_total_size_ << "\n"
  1571. << "Total Communication Output Tensor Size:\t" << comm_output_total_size_ << "\n"
  1572. << "Total LifeLong All Tensor Size:\t" << lifelong_all_total_size_ << "\n"
  1573. << "Total LifeLong Start Tensor Size:\t" << lifelong_start_total_size_ << "\n"
  1574. << "Total LifeLong End Tensor Size:\t" << lifelong_end_total_size_ << "\n"
  1575. << "Reused Size(Allocate Size):\t" << GetTotalMemSize() << "\n\n\n";
  1576. }
  1577. uint8_t *Somas::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
  1578. MS_EXCEPTION_IF_NULL(node);
  1579. auto key = node.get();
  1580. auto iter = nodes_map_.find(key);
  1581. uint8_t *ptr = nullptr;
  1582. if (iter != nodes_map_.end()) {
  1583. auto &somas_node = iter->second.at(0);
  1584. MS_EXCEPTION_IF_NULL(somas_node);
  1585. if (index >= somas_node->output_tensors_.size()) {
  1586. MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's output size:["
  1587. << somas_node->output_tensors_.size() << "]";
  1588. }
  1589. auto output_tensor = somas_node->output_tensors_[index];
  1590. ptr = mem_base_addr_ + output_tensor->offset_;
  1591. } else {
  1592. MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in nodes_map";
  1593. }
  1594. return ptr;
  1595. }
  1596. uint8_t *Somas::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const {
  1597. MS_EXCEPTION_IF_NULL(node);
  1598. auto key = node.get();
  1599. auto iter = nodes_map_.find(key);
  1600. uint8_t *ptr = nullptr;
  1601. if (iter != nodes_map_.end()) {
  1602. auto &somas_node = iter->second.at(0);
  1603. MS_EXCEPTION_IF_NULL(somas_node);
  1604. if (index >= somas_node->workspace_tensors_.size()) {
  1605. MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:["
  1606. << somas_node->workspace_tensors_.size() << "]";
  1607. }
  1608. auto workspace_tensor = somas_node->workspace_tensors_[index];
  1609. ptr = mem_base_addr_ + workspace_tensor->offset_;
  1610. }
  1611. return ptr;
  1612. }
void Somas::ConvertToProfilingNode(uint32_t graph_id) {
  // Export the SOMAS result for graph `graph_id` into the memory-profiling
  // structures: per-tensor size/lifetime records and per-node tensor-id lists.
  // Compiled only for the Ascend (D) backend; a no-op otherwise.
#ifdef ENABLE_D
  auto graph_node = MemoryProfiling::GetInstance().GetGraphMemoryNode(graph_id);
  if (graph_node == nullptr) {
    // First export for this graph: create its profiling node lazily.
    graph_node = MemoryProfiling::GetInstance().AddGraphMemoryNode(graph_id);
    MS_LOG(INFO) << "Add graph memory node for dynamic memory profiling, graph id is " << graph_id;
  }
  // Register every tensor's id, aligned size, type, lifetime and lifelong category.
  for (const auto &tensor : tensors_list_) {
    TensorMemory tensor_memory;
    tensor_memory.SetTensorId(tensor->GetId());
    tensor_memory.SetAlignedSize(tensor->GetAlignedSize());
    tensor_memory.SetType(tensor_type_name_map[tensor->type_]);
    tensor_memory.SetLifeStart(tensor->lifetime_.start_);
    tensor_memory.SetLifeEnd(tensor->lifetime_.end_);
    tensor_memory.SetLifeLong(life_long_name_map[tensor->lifelong_value_]);
    graph_node->AddTensorMemory(tensor_memory);
  }
  // Register every node with its short name and input/output/workspace tensor ids.
  for (const auto &node : nodes_list_) {
    NodeMemory node_memory;
    std::string name = GetSplitName(node->scope_full_name_);
    node_memory.SetNodeName(name);
    node_memory.SetNodeId(node->GetId());
    for (const auto &input_tensor : node->input_tensors_) {
      node_memory.AddInputTensorId(input_tensor->GetId());
    }
    for (const auto &output_tensor : node->output_tensors_) {
      node_memory.AddOutputTensorId(output_tensor->GetId());
    }
    for (const auto &workspace_tensor : node->workspace_tensors_) {
      node_memory.AddWorkSpaceTensorId(workspace_tensor->GetId());
    }
    graph_node->AddNodeMemory(node_memory);
  }
#endif
}
  1648. } // namespace somas
  1649. } // namespace mindspore