You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

kernel_runtime.cc 52 kB

5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/device/kernel_runtime.h"
  17. #include <functional>
  18. #include <utility>
  19. #include <vector>
  20. #include "backend/optimizer/common/helper.h"
  21. #include "backend/session/anf_runtime_algorithm.h"
  22. #include "backend/session/kernel_graph.h"
  23. #include "common/trans.h"
  24. #include "debug/data_dump/dump_json_parser.h"
  25. #include "frontend/operator/ops.h"
  26. #include "ir/value.h"
  27. #include "utils/ms_context.h"
  28. #include "utils/ms_utils.h"
  29. #include "utils/shape_utils.h"
  30. #include "utils/utils.h"
  31. #include "frontend/parallel/context.h"
  32. #include "debug/env_config_parser.h"
  33. #if (ENABLE_CPU && !_WIN32)
  34. #include "ps/ps_cache/ps_cache_manager.h"
  35. #endif
  36. using mindspore::kernel::Address;
  37. using mindspore::kernel::AddressPtr;
  38. namespace mindspore {
  39. namespace device {
// Virtual destructor with no work to do here; device-specific runtimes
// release their own resources in their derived destructors.
KernelRuntime::~KernelRuntime() {}
// Base-class hook for loading a graph (optionally in task-sink mode); the
// default does nothing and reports success. Overridden by concrete runtimes.
bool KernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { return true; }
// Base-class hook for loading debug data; always reports failure here and is
// overridden by runtimes that actually support it.
bool KernelRuntime::LoadData(session::KernelGraph *graph) { return false; }
  43. bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
  44. MS_EXCEPTION_IF_NULL(kernel);
  45. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  46. const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
  47. MS_EXCEPTION_IF_NULL(address);
  48. return address->DeviceType() == GetTargetDeviceAddressType();
  49. }
  50. return false;
  51. }
// Assign all memory for `graph`: reset the dynamic region, then assign static
// memory (inputs/value nodes/outputs), dynamic memory, and finally rebind
// ref-node outputs so they alias their origin node's address.
void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  // Start dynamic allocation from a clean offset.
  mem_manager_->ResetDynamicMemory();
  AssignStaticMemory(graph);
  AssignDynamicMemory(graph);
  UpdateRefNodeOutputMem(graph);
}
// Assign all memory needed to run a single op: reset the dynamic pool, bind
// the input tensors, allocate value-node static memory, then per-kernel output
// and workspace memory, and finally fix up ref-node output addresses.
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                      session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->ResetDynamicMemory();
  RunOpAssignInputMemory(input_tensors, graph);
  AssignStaticMemoryValueNode(graph);
  for (const auto &cnode : graph->execution_order()) {
    RunOpAssignOutputMemory(cnode);
    RunOpAssignWorkSpaceMemory(cnode);
  }
  UpdateRefNodeOutputMem(graph);
}
  74. void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) const {
  75. MS_EXCEPTION_IF_NULL(graph);
  76. // clear input parameter memory resource
  77. for (const auto &input_node : graph->inputs()) {
  78. MS_EXCEPTION_IF_NULL(input_node);
  79. AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
  80. }
  81. // clear input value node memory resource
  82. for (const auto &value_node : graph->graph_value_nodes()) {
  83. MS_EXCEPTION_IF_NULL(value_node);
  84. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  85. }
  86. for (const auto &cnode : graph->execution_order()) {
  87. MS_EXCEPTION_IF_NULL(cnode);
  88. // clear output memory resource
  89. size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  90. for (size_t index = 0; index < output_num; ++index) {
  91. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  92. }
  93. // clear workspace memory resource
  94. auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  95. MS_EXCEPTION_IF_NULL(kernel_mod);
  96. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  97. for (size_t index = 0; index < workspace_lists.size(); ++index) {
  98. AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
  99. }
  100. }
  101. }
  102. bool KernelRuntime::DumpDataEnabled() {
  103. auto &dump_json_parser = DumpJsonParser::GetInstance();
  104. return dump_json_parser.e2e_dump_enabled();
  105. }
  106. bool KernelRuntime::DumpDataEnabledIteration() {
  107. auto &dump_json_parser = DumpJsonParser::GetInstance();
  108. if (!dump_json_parser.e2e_dump_enabled()) {
  109. return false;
  110. }
  111. auto cur_iter = dump_json_parser.cur_dump_iter() + 1;
  112. if (dump_json_parser.iteration() != 0) {
  113. return cur_iter == dump_json_parser.iteration();
  114. }
  115. return true;
  116. }
// Assign static memory for graph inputs, value nodes and graph outputs,
// in that order.
void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) {
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  AssignStaticMemoryOutput(graph);
}
  122. void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
  123. const session::KernelGraph *graph) {
  124. MS_EXCEPTION_IF_NULL(graph);
  125. MS_EXCEPTION_IF_NULL(mem_manager_);
  126. if (input_tensors.size() != graph->inputs().size()) {
  127. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
  128. << " should be equal to graph input parameter size " << graph->inputs().size();
  129. }
  130. for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) {
  131. auto item = graph->inputs()[input_index];
  132. MS_EXCEPTION_IF_NULL(item);
  133. if (!item->isa<Parameter>()) {
  134. continue;
  135. }
  136. auto output_size = AnfAlgo::GetOutputTensorNum(item);
  137. for (size_t index = 0; index < output_size; index++) {
  138. MS_EXCEPTION_IF_NULL(input_tensors[input_index]);
  139. auto output_address =
  140. std::dynamic_pointer_cast<device::DeviceAddress>(input_tensors[input_index]->device_address());
  141. if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
  142. AnfAlgo::SetOutputAddr(output_address, index, item.get());
  143. continue;
  144. }
  145. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  146. if (output_type_id == kTypeUnknown) {
  147. output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
  148. }
  149. auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
  150. auto device_address =
  151. CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
  152. MS_EXCEPTION_IF_NULL(device_address);
  153. MS_EXCEPTION_IF_NULL(mem_manager_);
  154. auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
  155. if (!ret) {
  156. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;
  157. }
  158. AnfAlgo::SetOutputAddr(device_address, index, item.get());
  159. }
  160. }
  161. }
  162. void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
  163. MS_EXCEPTION_IF_NULL(kernel);
  164. MS_EXCEPTION_IF_NULL(mem_manager_);
  165. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  166. MS_EXCEPTION_IF_NULL(kernel_mod);
  167. auto output_sizes = kernel_mod->GetOutputSizeList();
  168. if (output_sizes.empty()) {
  169. return;
  170. }
  171. for (size_t i = 0; i < output_sizes.size(); ++i) {
  172. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  173. continue;
  174. }
  175. if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) {
  176. auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
  177. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  178. continue;
  179. }
  180. std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
  181. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  182. auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  183. device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
  184. MS_EXCEPTION_IF_NULL(device_address);
  185. auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
  186. if (!ret) {
  187. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << output_sizes[i];
  188. }
  189. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  190. }
  191. }
  192. void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  193. MS_EXCEPTION_IF_NULL(kernel);
  194. MS_EXCEPTION_IF_NULL(mem_manager_);
  195. if (kernel->isa<CNode>()) {
  196. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  197. MS_EXCEPTION_IF_NULL(kernel_mod);
  198. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  199. for (size_t i = 0; i < workspace_lists.size(); ++i) {
  200. auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
  201. MS_EXCEPTION_IF_NULL(device_address);
  202. auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
  203. if (!ret) {
  204. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << workspace_lists[i];
  205. }
  206. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  207. }
  208. }
  209. }
// Share the device addresses of a previous step's output tensors with the
// output nodes of `graph`, so single-op execution can reuse that memory
// instead of allocating fresh output buffers. A null `pre_output_value`
// means there is nothing to share.
void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph) {
  if (pre_output_value == nullptr) {
    return;
  }
  std::vector<tensor::TensorPtr> pre_output_tensors;
  TensorValueToTensor(pre_output_value, &pre_output_tensors);
  MS_EXCEPTION_IF_NULL(graph);
  auto output_nodes = graph->outputs();
  if (pre_output_tensors.size() != output_nodes.size()) {
    MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size()
                      << "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]";
  }
  // share output address with pre output tensors
  for (size_t i = 0; i < output_nodes.size(); ++i) {
    auto output_node_with_index = AnfAlgo::VisitKernel(output_nodes[i], 0);
    if (!output_node_with_index.first->isa<CNode>()) {
      // Non-CNode outputs must be weights (parameters with a default value).
      if (output_node_with_index.first->isa<Parameter>()) {
        auto param = output_node_with_index.first->cast<ParameterPtr>();
        if (!param->has_default()) {
          MS_LOG(EXCEPTION) << "The output parameter should be real parameter!";
        }
      }
      continue;
    }
    auto real_output_cnode = output_node_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(real_output_cnode);
    MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
    if (pre_output_tensors[i]->device_address() == nullptr) {
      MS_LOG(INFO) << "The address of pre output tensor [" << i << "] is a nullptr!";
      continue;
    }
    // A nop node forwards its first real input, so the shared address is
    // attached to that input rather than to the nop node itself.
    if (opt::IsNopNode(real_output_cnode)) {
      if (real_output_cnode->inputs().size() < 2) {
        MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString()
                          << " should large than one!";
      }
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, real_output_cnode->input(1).get());
    } else {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, output_node_with_index.first.get());
    }
  }
}
// Allocate static device memory for the graph's input parameters, including
// parameters reachable through MakeTuple inputs and appended child-graph
// results. Parameters that already own a target-type address or are filtered
// out by valid_inputs are skipped. In PS-cache builds, hash-table parameters
// reuse the embedding cache's preallocated memory instead of allocating.
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  MS_LOG(INFO) << "AssignStaticMemoryInput start";
  auto graph_inputs = graph->inputs();
  auto graph_valid_input = graph->valid_inputs();
  graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
  std::vector<AnfNodePtr> need_alloc_nodes;
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    auto item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    // Only the original inputs carry a valid flag; appended child-graph
    // results (i >= valid_inputs size) are always considered.
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
      // Collect parameters wrapped inside a MakeTuple input.
      auto outs = AnfAlgo::GetAllOutput(item);
      for (auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        if (!out->isa<Parameter>()) {
          continue;
        }
        if (NodeOutputDeviceAddressExist(out, 0)) {
          continue;
        }
        need_alloc_nodes.push_back(out);
      }
    }
    if (!item->isa<Parameter>()) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(item, 0)) {
      continue;
    }
    need_alloc_nodes.push_back(item);
  }
#if (ENABLE_CPU && !_WIN32)
  // Run the PS embedding-cache compatibility check at most once per call.
  bool ps_cache_check = false;
#endif
  for (auto &item : need_alloc_nodes) {
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        continue;
      }
      DeviceAddressPtr device_address = nullptr;
#if (ENABLE_CPU && !_WIN32)
      const std::string &param_name = item->fullname_with_scope();
      if (ps::ps_cache_instance.IsHashTable(param_name)) {
        MS_LOG(INFO) << "Parameter(" << param_name << ")"
                     << " enables the embeddingLookup cache in parameter server training mode.";
        // PS embeddingLookup cache check.
        if (!ps_cache_check) {
          CheckIfSupportPSEmbeddingCache(graph);
          ps_cache_check = true;
        }
        // Wrap the cache's existing memory; no new allocation is performed.
        const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
        MS_EXCEPTION_IF_NULL(address.addr);
        device_address =
          CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
        AnfAlgo::SetOutputAddr(device_address, index, item.get());
        continue;
      }
#endif
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
      MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
                   << " index: " << index << " size: " << tensor_size;
      if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
      }
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
  MS_LOG(INFO) << "AssignStaticMemoryInput end";
}
  332. void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
  333. MS_EXCEPTION_IF_NULL(graph);
  334. MS_LOG(INFO) << "AssignStaticMemoryOutput start";
  335. auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  336. std::vector<session::KernelWithIndex> non_communication_op;
  337. // Assign Communicate Op Memory firstly.
  338. for (const auto &node : nodes) {
  339. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
  340. MS_EXCEPTION_IF_NULL(item_with_index.first);
  341. if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
  342. continue;
  343. }
  344. if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
  345. AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
  346. } else {
  347. non_communication_op.emplace_back(item_with_index);
  348. }
  349. }
  350. for (const auto &item_with_index : non_communication_op) {
  351. MS_LOG(DEBUG) << "AssignNodeOutputMem for " << item_with_index.first->fullname_with_scope();
  352. AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  353. }
  354. MS_LOG(INFO) << "AssignStaticMemoryOutput end";
  355. }
// For each kernel output registered in the graph's ref-output map, rebind the
// output's device address to that of its corresponding origin node, so ref
// outputs alias the memory they reference.
void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto &kernels = graph->execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto output_sizes = kernel_mod->GetOutputSizeList();
    if (output_sizes.empty()) {
      MS_LOG(INFO) << "This kernel has no output size.";
      continue;
    }
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      session::AnfWithOutIndex out_pair(kernel, i);
      if (graph->IsInRefOutputMap(out_pair)) {
        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
        MS_EXCEPTION_IF_NULL(origin_node_output_addr);
        auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
        // Only rebind when the two addresses actually differ.
        if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
          MS_LOG(DEBUG) << "REF address is not same, ref node output need address update";
          MS_LOG(DEBUG) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                        << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}
// Assign input, output and workspace memory for a communication node as a
// group, since communication kernels need their buffers laid out together.
void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
  AssignCommunicationNodeInputMem(type, node);
  AssignCommunicationNodeOutputMem(type, node);
  AssignWorkSpaceMem(type, node);
}
// Build per-kernel pre/post run events that synchronize communication kernels
// (run on communication_stream_) with compute kernels (run on stream_).
// For each communication kernel: a pre-event is recorded on the compute stream
// and waited on by the communication stream before the kernel runs; a
// post-event is recorded on the communication stream and waited on by the
// nearest later kernel that consumes this kernel's output — or, if no consumer
// is found, by the kernel's own post-run events. Results are cached per
// graph id in graph_kernel_events_map_.
void KernelRuntime::GenKernelEvents(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto &kernels = graph->execution_order();
  // Skip empty graphs and graphs whose events were already generated.
  if (kernels.empty() || graph_kernel_events_map_.find(graph->graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  // first = per-kernel pre-run callbacks, second = per-kernel post-run callbacks.
  auto kernel_events =
    std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  kernel_pre_run_events.resize(kernels.size());
  kernel_post_run_events.resize(kernels.size());
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &kernel = kernels[i];
    if (!AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto pre_event = CreateDeviceEvent();
    auto post_event = CreateDeviceEvent();
    pre_event->set_wait_stream(communication_stream_);
    pre_event->set_record_stream(stream_);
    post_event->set_wait_stream(stream_);
    post_event->set_record_stream(communication_stream_);
    // Before the kernel runs: record on the compute stream, then make the
    // communication stream wait for that point.
    kernel_pre_run_events[i].emplace_back([pre_event]() {
      pre_event->RecordEvent();
      pre_event->WaitEvent();
    });
    kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });
    // Search forward for the nearest kernel that consumes this kernel's output.
    bool found_nearest_child = false;
    for (size_t j = i + 1; j < kernels.size(); ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      // inputs()[0] is the primitive, hence size - 1 real inputs.
      auto input_size = child->inputs().size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
        break;
      }
    }
    // No consumer found: synchronize immediately after the kernel itself.
    if (!found_nearest_child) {
      kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
  graph_kernel_events_map_[graph->graph_id()] = std::move(kernel_events);
}
// Allocate the outputs of a communication node as one contiguous buffer: the
// first output carries the real allocation of the total (aligned) size, and
// each subsequent output's address points into the block at a running offset.
void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  size_t output_index = 0;
  std::vector<size_t> align_size_list;
  for (uint64_t mem_size : output_sizes) {
    // If any output already has an address, the node was assigned before.
    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
      MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address";
      return;
    }
    if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
      // With HCCL enabled, each chunk of the contiguous buffer is aligned.
      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    align_size_list.emplace_back(mem_size);
  }
  if (align_size_list.empty()) {
    return;
  }
  if (type == kReuseDynamicMem) {
    // reuse communication op's all outputs' memory
    type = kReuseDynamicCommMem;
  }
  if (type == kReuseDynamicCommMem || type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }
  uint8_t *output_ptr = nullptr;
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
    MS_EXCEPTION_IF_NULL(address);
    if (output_ptr == nullptr) {
      // First output: allocate the whole contiguous block in one call.
      output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
      MS_EXCEPTION_IF_NULL(output_ptr);
    } else {
      // Later outputs point into the block at the running aligned offset.
      address->set_ptr(output_ptr);
    }
    AnfAlgo::SetOutputAddr(address, j, node.get());
    output_ptr += align_size_list[j];
  }
}
// Base-class hook: returns false so memory reuse stays enabled for every node;
// device-specific runtimes override this to exclude particular nodes.
bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) { return false; }
  499. DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
  500. MS_EXCEPTION_IF_NULL(anf_node);
  501. if (!anf_node->isa<CNode>()) {
  502. MS_LOG(EXCEPTION) << "anf_node should be a cnode";
  503. }
  504. auto cnode = anf_node->cast<CNodePtr>();
  505. if (opt::IsNopNode(cnode)) {
  506. const size_t kNopNodeInputSize = 2;
  507. if (cnode->size() != kNopNodeInputSize) {
  508. MS_LOG(EXCEPTION) << cnode->fullname_with_scope() << " has invalid input size: " << cnode->size();
  509. }
  510. auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
  511. return PreAssignCNodeMemory(input_node_with_index.first, input_node_with_index.second);
  512. }
  513. auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  514. MS_EXCEPTION_IF_NULL(kernel_mod);
  515. auto output_sizes = kernel_mod->GetOutputSizeList();
  516. if (output_sizes.size() <= index) {
  517. MS_LOG(EXCEPTION) << "Previous node output size < node index";
  518. }
  519. std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  520. auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  521. auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
  522. AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  523. return address;
  524. }
// Lay out the inputs of a communication node in one contiguous buffer: each
// producing CNode first gets a (pointer-less) device address, then a single
// block of the total aligned size is allocated at the first input's producer
// and every address is pointed into it at a running offset.
void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<std::pair<DeviceAddressPtr, size_t>> addr_size;
  size_t input_num = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i, true);
    auto input_node = input_node_with_index.first;
    // If any input already has an address, the node was handled before.
    if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
      MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address";
      return;
    }
    DeviceAddressPtr address = nullptr;
    if (input_node->isa<CNode>()) {
      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
    } else {
      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
    }
    MS_EXCEPTION_IF_NULL(address);
    auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
    total_size += mem_size;
    addr_size.emplace_back(address, mem_size);
  }
  if (addr_size.empty()) {
    return;
  }
  if (type == kReuseDynamicMem || type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input.";
    }
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().size() < 2) {
    // communication node's input should contain itself and at least on input
    MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
    return;
  }
  // Allocate the whole contiguous block at the first real input's producer.
  auto first_input_node = cnode->input(1);
  auto prenode_index = AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true);
  uint8_t *input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size,
                                                     addr_size[0].first, true);
  // NOTE(review): input_ptr is not null-checked before use here, unlike the
  // output path — confirm MallocOutputMem cannot return nullptr in this mode.
  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}
// Assign output memory for a single node. `index` selects one output slot, or
// kGetAllOuts to assign every output. Memory reuse is downgraded to plain
// dynamic memory for GetNext, for independent nodes, and for nodes vetoed by
// KernelMemNotReuse.
void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (AnfAlgo::IsGetNext(NOT_NULL(node)) && type == kReuseDynamicMem) {
    MS_LOG(INFO) << "GetNext disable mem_reuse";
    type = kDynamicMem;
  }
  if (node->isa<CNode>()) {
    bool independent = AnfAlgo::IsIndependentNode(node->cast<CNodePtr>());
    if (independent && (type == kReuseDynamicMem)) {
      MS_LOG(INFO) << "Independent node " << node->fullname_with_scope() << " disable memory reuse";
      type = kDynamicMem;
    }
  }
  if (type == kReuseDynamicMem || type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    // Skip outputs not selected by `index` (unless assigning all outputs).
    if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(node, i)) {
      MS_LOG(INFO) << "Already malloc index:" << i;
      continue;
    }
    MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope() << " output memory size:" << output_sizes[i];
    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
    MS_EXCEPTION_IF_NULL(device_address);
    uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
    MS_EXCEPTION_IF_NULL(ptr);
    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
    AnfAlgo::SetOutputAddr(device_address, i, node.get());
  }
}
  624. void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
  625. size_t output_idx) {
  626. MS_EXCEPTION_IF_NULL(value_node);
  627. MS_EXCEPTION_IF_NULL(node_value);
  628. MS_EXCEPTION_IF_NULL(mem_manager_);
  629. auto ms_context = MsContext::GetInstance();
  630. MS_EXCEPTION_IF_NULL(ms_context);
  631. std::vector<tensor::TensorPtr> tensors;
  632. TensorValueToTensor(node_value, &tensors);
  633. // Graph id should be passed to record static memory if profiling is enabled.
  634. auto kernel_info = static_cast<device::KernelInfo *>(value_node->kernel_info());
  635. MS_EXCEPTION_IF_NULL(kernel_info);
  636. uint32_t graph_id = kernel_info->graph_id();
  637. for (const auto &tensor : tensors) {
  638. if (tensor == nullptr) {
  639. MS_LOG(WARNING) << "Tensor is null";
  640. return;
  641. }
  642. auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  643. if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
  644. AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
  645. value_node.get());
  646. continue;
  647. }
  648. size_t tensor_size = tensor->data().nbytes();
  649. auto node_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
  650. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  651. if (output_type_id == kTypeUnknown) {
  652. output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  653. }
  654. auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
  655. DeviceAddressPtr address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
  656. MS_EXCEPTION_IF_NULL(address);
  657. if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
  658. !mem_manager_->MallocMemFromMemPool(address, node_size)) {
  659. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << node_size;
  660. } else if (mem_manager_->MallocMem(kStaticMem, node_size, address, graph_id) == nullptr) {
  661. MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
  662. }
  663. AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
  664. if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
  665. tensor->data_c())) {
  666. MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString()
  667. << "node format is" << AnfAlgo::GetOutputFormat(value_node, output_idx)
  668. << "node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  669. }
  670. }
  671. return;
  672. }
  673. void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
  674. MS_EXCEPTION_IF_NULL(graph);
  675. MS_EXCEPTION_IF_NULL(mem_manager_);
  676. MS_LOG(INFO) << "AssignStaticMemoryValueNode start";
  677. auto ms_context = MsContext::GetInstance();
  678. MS_EXCEPTION_IF_NULL(ms_context);
  679. // order the value nodes
  680. std::map<std::string, ValueNodePtr> value_nodes_map;
  681. for (auto &node : graph->graph_value_nodes()) {
  682. value_nodes_map[node->fullname_with_scope()] = node;
  683. }
  684. for (auto &item : value_nodes_map) {
  685. auto value_node = item.second;
  686. MS_EXCEPTION_IF_NULL(value_node);
  687. if (NodeOutputDeviceAddressExist(value_node, 0)) {
  688. MS_LOG(DEBUG) << "value_node[" << value_node->DebugString() << "] address already exist";
  689. continue;
  690. }
  691. auto &node_value = value_node->value();
  692. MS_EXCEPTION_IF_NULL(node_value);
  693. MS_LOG(DEBUG) << "Malloc memory for " << value_node->fullname_with_scope();
  694. if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>()) {
  695. AssignValueNodeTensor(value_node, node_value, 0);
  696. } else if (node_value->isa<StringImm>()) {
  697. auto value = GetValue<std::string>(node_value);
  698. size_t tensor_size = value.size();
  699. DeviceAddressPtr address = nullptr;
  700. address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  701. MS_EXCEPTION_IF_NULL(address);
  702. if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
  703. !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
  704. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;
  705. } else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address, graph->graph_id()) == nullptr) {
  706. MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
  707. }
  708. AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  709. ShapeVector shape = {1, SizeToLong(tensor_size)};
  710. if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
  711. MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!";
  712. }
  713. }
  714. }
  715. MS_LOG(INFO) << "AssignStaticMemoryValueNode end";
  716. }
  717. void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
  718. MS_EXCEPTION_IF_NULL(graph);
  719. MS_EXCEPTION_IF_NULL(mem_manager_);
  720. auto context_ptr = MsContext::GetInstance();
  721. MS_EXCEPTION_IF_NULL(context_ptr);
  722. bool is_enable_mem_reuse = EnvConfigParser::GetInstance().GetSysMemreuse();
  723. auto mem_type = kDynamicMem;
  724. auto &dump_json_parser = DumpJsonParser::GetInstance();
  725. if (dump_json_parser.e2e_dump_enabled() && dump_json_parser.dump_mode() == 0) {
  726. mindspore::EnvConfigParser::GetInstance().SetSysMemreuse(false);
  727. is_enable_mem_reuse = false;
  728. MS_LOG(INFO) << "Disable Memory Reuse when e2e dump is enable and dump mode is set to dump all kernels";
  729. }
  730. if (is_enable_mem_reuse) {
  731. MS_LOG(INFO) << "Memory Reuse is enable...";
  732. #ifdef MEM_REUSE_DEBUG
  733. mem_manager_->MallocReusedDynamicMem(graph);
  734. mem_type = kReuseDynamicMem;
  735. #else
  736. mem_manager_->MallocSomasDynamicMem(graph);
  737. mem_type = kSomasReuseDynamicMem;
  738. #endif
  739. } else {
  740. MS_LOG(INFO) << "Memory Reuse is disable...";
  741. }
  742. auto &execution_nodes = graph->execution_order();
  743. std::vector<CNodePtr> compute_nodes;
  744. // communication nodes first
  745. for (auto &node : execution_nodes) {
  746. if (AnfAlgo::IsCommunicationOp(node)) {
  747. // skip if the memory is already allocated
  748. AssignCommunicationNodeMem(mem_type, node);
  749. } else {
  750. compute_nodes.emplace_back(node);
  751. }
  752. }
  753. // then compute nodes
  754. for (auto &node : compute_nodes) {
  755. AssignNodeOutputMem(mem_type, node, kGetAllOuts);
  756. AssignWorkSpaceMem(mem_type, node);
  757. }
  758. }
  759. void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
  760. MS_EXCEPTION_IF_NULL(node);
  761. MS_EXCEPTION_IF_NULL(mem_manager_);
  762. auto kernel_mod = AnfAlgo::GetKernelMod(node);
  763. MS_EXCEPTION_IF_NULL(kernel_mod);
  764. size_t index = 0;
  765. for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
  766. if (AnfAlgo::WorkspaceAddrExist(node, index)) {
  767. MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address";
  768. return;
  769. }
  770. auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
  771. AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
  772. index++;
  773. }
  774. }
  775. void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
  776. AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
  777. AddressPtrList *kernel_outputs) {
  778. MS_EXCEPTION_IF_NULL(kernel);
  779. MS_EXCEPTION_IF_NULL(kernel_inputs);
  780. MS_EXCEPTION_IF_NULL(kernel_workspaces);
  781. MS_EXCEPTION_IF_NULL(kernel_outputs);
  782. auto cnode = kernel->cast<CNodePtr>();
  783. MS_EXCEPTION_IF_NULL(cnode);
  784. if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
  785. return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
  786. }
  787. auto ms_context = MsContext::GetInstance();
  788. MS_EXCEPTION_IF_NULL(ms_context);
  789. auto visit_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
  790. size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
  791. for (size_t i = 0; i < input_num; ++i) {
  792. auto op_name = AnfAlgo::GetCNodeName(cnode);
  793. constexpr auto none_placeholder_index = 3;
  794. if (op_name == kDynamicRNNOpName && i == none_placeholder_index) {
  795. continue;
  796. }
  797. if (op_name == kDynamicGRUV2OpName) {
  798. auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
  799. auto item = std::find(none_index.begin(), none_index.end(), i);
  800. if (item != none_index.end()) {
  801. continue;
  802. }
  803. }
  804. auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
  805. auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input, visit_nop_node);
  806. MS_EXCEPTION_IF_NULL(device_address);
  807. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  808. MS_EXCEPTION_IF_NULL(input);
  809. input->addr = device_address->ptr_;
  810. MS_EXCEPTION_IF_NULL(input->addr);
  811. input->size = device_address->size_;
  812. kernel_inputs->emplace_back(input);
  813. }
  814. for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
  815. auto device_address = AnfAlgo::GetOutputAddr(kernel, i, visit_nop_node);
  816. kernel::AddressPtr output = std::make_shared<kernel::Address>();
  817. MS_EXCEPTION_IF_NULL(output);
  818. output->addr = device_address->ptr_;
  819. MS_EXCEPTION_IF_NULL(output->addr);
  820. output->size = device_address->size_;
  821. kernel_outputs->emplace_back(output);
  822. }
  823. for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
  824. auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
  825. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  826. MS_EXCEPTION_IF_NULL(workspace);
  827. workspace->addr = device_address->ptr_;
  828. MS_EXCEPTION_IF_NULL(workspace->addr);
  829. workspace->size = device_address->size_;
  830. kernel_workspaces->emplace_back(workspace);
  831. }
  832. }
  833. void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) {
  834. if (cnode->inputs().size() != 2) {
  835. MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
  836. }
  837. MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
  838. auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  839. // set clean output address
  840. if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
  841. #if defined(__APPLE__)
  842. auto clean_output_indexes = AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicOutputIndexs);
  843. #else
  844. auto clean_output_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
  845. #endif
  846. for (auto index : clean_output_indexes) {
  847. auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
  848. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  849. MS_EXCEPTION_IF_NULL(input);
  850. input->addr = device_address->ptr_;
  851. MS_EXCEPTION_IF_NULL(input->addr);
  852. input->size = device_address->size_;
  853. kernel_inputs->emplace_back(input);
  854. }
  855. MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexes.size();
  856. }
  857. // set clean workspace address
  858. if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
  859. #if defined(__APPLE__)
  860. auto clean_workspaces_indexes = AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicWorkspaceIndexs);
  861. #else
  862. auto clean_workspaces_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
  863. #endif
  864. for (const auto &index : clean_workspaces_indexes) {
  865. auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
  866. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  867. MS_EXCEPTION_IF_NULL(workspace);
  868. workspace->addr = device_address->ptr_;
  869. MS_EXCEPTION_IF_NULL(workspace->addr);
  870. workspace->size = device_address->size_;
  871. kernel_inputs->emplace_back(workspace);
  872. }
  873. }
  874. }
  875. void KernelRuntime::LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &kernel_events,
  876. size_t index) {
  877. if (index >= kernel_events.size()) {
  878. return;
  879. }
  880. for (auto &event : kernel_events[index]) {
  881. event();
  882. }
  883. }
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
  // Launch every kernel of `graph` in execution order. Dynamic-shape kernels
  // go through their DynamicKernel wrapper (infer/update/execute/sync); static
  // kernels go through KernelMod::Launch. Returns false on the first
  // launch or stream-sync failure.
  const auto &kernels = graph.execution_order();
  std::vector<DynamicKernelPtr> dynamic_kernel_list;
  auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
  if (iter != graph_dynamic_kernel_map_.end()) {
    dynamic_kernel_list = iter->second;
  }
  // When a dynamic-kernel list exists it must be index-aligned with `kernels`.
  if (!dynamic_kernel_list.empty() && dynamic_kernel_list.size() != kernels.size()) {
    MS_LOG(EXCEPTION) << "The size of dynamic kernels " << dynamic_kernel_list.size()
                      << " should be equal to the size of kernels " << kernels.size();
  }
  // Optional per-kernel pre/post callbacks registered for this graph id.
  std::vector<std::vector<std::function<void()>>> kernel_pre_run_events;
  std::vector<std::vector<std::function<void()>>> kernel_post_run_events;
  auto events_iter = graph_kernel_events_map_.find(graph.graph_id());
  if (events_iter != graph_kernel_events_map_.end()) {
    kernel_pre_run_events = events_iter->second.first;
    kernel_post_run_events = events_iter->second.second;
  }
  for (size_t i = 0; i < kernels.size(); ++i) {
    LaunchKernelEvent(kernel_pre_run_events, i);
    if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr &&
        dynamic_kernel_list[i]->is_dynamic_shape()) {
      // Dynamic-shape path: re-infer shapes, refresh launch args, execute,
      // then sync the stream so PostExecute sees the finished result.
      dynamic_kernel_list[i]->InferShape();
      dynamic_kernel_list[i]->UpdateArgs();
      dynamic_kernel_list[i]->Execute();
      if (!SyncStream()) {
        MS_LOG(ERROR) << "SyncStream failed";
        return false;
      }
      dynamic_kernel_list[i]->PostExecute();
    } else {
      auto &kernel = kernels[i];
      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
      MS_EXCEPTION_IF_NULL(kernel_mod);
      // Skip transpose kernel with "nop_op" attr which is not hidden or removed in PyNative infer scenario. Transpose
      // kernel, which is not supposed to be executed, is generated in TransDataSplit to support specific Transdata. And
      // hard code here should be removed after new Transdata programme is implemented in the foreseeable future.
      if (AnfAlgo::HasNodeAttr("nop_op", kernel)) {
        // Forward each input address straight to the corresponding output
        // instead of launching the kernel.
        for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(kernel); idx += 1) {
          auto real_input = AnfAlgo::GetRealInputIndex(kernel, idx);
          auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input);
          AnfAlgo::SetOutputAddr(device_address, idx, kernel.get());
        }
        continue;
      }
      AddressPtrList kernel_inputs;
      AddressPtrList kernel_workspaces;
      AddressPtrList kernel_outputs;
      GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
      bool ret;
      // Communication ops run on the dedicated communication stream; everything
      // else uses the default compute stream.
      if (AnfAlgo::IsCommunicationOp(kernel)) {
        ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, communication_stream_);
      } else {
        ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
      }
      if (!ret) {
        MS_LOG(ERROR) << "Launch kernel failed.";
        return false;
      }
      KernelLaunchProfiling(kernels[i]->fullname_with_scope());
    }
    LaunchKernelEvent(kernel_post_run_events, i);
  }
  return true;
}
  949. bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
  950. MS_EXCEPTION_IF_NULL(graph);
  951. if (!LaunchKernelMod(*graph)) {
  952. MS_LOG(ERROR) << "LaunchKernelMod failed!";
  953. return false;
  954. }
  955. auto ms_context = MsContext::GetInstance();
  956. MS_EXCEPTION_IF_NULL(ms_context);
  957. if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  958. if (!SyncStream()) {
  959. MS_LOG(ERROR) << "SyncStream failed";
  960. return false;
  961. }
  962. }
  963. return true;
  964. }
void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &,
                                              const std::unordered_set<ValueNodePtr> &, const std::vector<CNodePtr> &) {
  // Base implementation only logs; presumably device-specific runtimes override
  // this to actually release per-graph resources — confirm against subclasses.
  MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}
  969. void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
  970. const std::unordered_set<ValueNodePtr> &value_nodes,
  971. const std::vector<CNodePtr> &execution_order) {
  972. // clear input parameter output address.
  973. for (const auto &input_node : inputs) {
  974. MS_EXCEPTION_IF_NULL(input_node);
  975. if (!input_node->isa<Parameter>()) {
  976. continue;
  977. }
  978. auto parameter = input_node->cast<ParameterPtr>();
  979. MS_EXCEPTION_IF_NULL(parameter);
  980. parameter->DecreaseUsedGraphCount();
  981. // Only the parameter has no graph used, then clear the output address.
  982. if (parameter->used_graph_count() != 0) {
  983. continue;
  984. }
  985. size_t output_num = AnfAlgo::GetOutputTensorNum(input_node);
  986. for (size_t index = 0; index < output_num; ++index) {
  987. if (!AnfAlgo::OutputAddrExist(input_node, index)) {
  988. continue;
  989. }
  990. AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
  991. }
  992. }
  993. // clear input value node output address.
  994. for (const auto &value_node : value_nodes) {
  995. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  996. continue;
  997. }
  998. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  999. }
  1000. // clear cnode output address.
  1001. for (const auto &cnode : execution_order) {
  1002. size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  1003. for (size_t index = 0; index < output_num; ++index) {
  1004. if (!AnfAlgo::OutputAddrExist(cnode, index)) {
  1005. continue;
  1006. }
  1007. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  1008. }
  1009. }
  1010. }
  1011. bool KernelRuntime::LaunchTaskBasedOnSingleKernel(const kernel::KernelModPtr &kernel_mod_ptr,
  1012. const AddressPtrList &kernel_inputs,
  1013. const AddressPtrList &kernel_outputs,
  1014. const AddressPtrList &kernel_workspaces) const {
  1015. MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
  1016. auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  1017. if (!ret) {
  1018. MS_LOG(ERROR) << "Launch kernel failed.";
  1019. return false;
  1020. }
  1021. return true;
  1022. }
  1023. DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const std::string &format, TypeId type) {
  1024. auto device_address = CreateDeviceAddress(nullptr, size, format, type);
  1025. MS_EXCEPTION_IF_NULL(device_address);
  1026. MS_EXCEPTION_IF_NULL(mem_manager_);
  1027. auto base_ptr = mem_manager_->MallocMem(kStaticMem, size, device_address);
  1028. MS_EXCEPTION_IF_NULL(base_ptr);
  1029. return device_address;
  1030. }
  1031. #if (ENABLE_CPU && !_WIN32)
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
                                             AnfNodePtr *const first_cache_input_index,
                                             size_t *const first_cache_size) {
  // Find the first Gather/SparseGather in execution order whose table parameter
  // is a PS-cache hash table; report the node producing its indices input and
  // the configured cache size through the out-parameters. Leaves the
  // out-parameters untouched when no such kernel exists.
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &kernel : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
      continue;
    }
    // Input 0 is the embedding table parameter, input 1 is the lookup indices.
    auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
    auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
    MS_EXCEPTION_IF_NULL(input_param.first);
    MS_EXCEPTION_IF_NULL(input_index.first);
    auto param_name = input_param.first->fullname_with_scope();
    if (!ps::ps_cache_instance.IsHashTable(param_name)) {
      continue;
    }
    auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
    // Walk back through any Cast nodes to reach the real indices producer.
    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
      MS_EXCEPTION_IF_NULL(input_index.first);
    }
    auto input_index_node_name = AnfAlgo::GetCNodeName(input_index.first);
    if (input_index.first->isa<CNode>() && (input_index_node_name != kGetNextOpName)) {
      // Indices must come from the dataset: directly from GetNext, or via
      // Unique (non-full-batch) / Minimum (full-batch). Anything else is
      // unsupported for PS caching.
      bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
      if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
          (full_batch && (input_index_node_name != kMinimumOpName))) {
        MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
                      << ") cache is from " << input_index.first->fullname_with_scope();
        MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
                             "parameter server training mode.";
      }
    }
    *first_cache_input_index = input_index.first;
    *first_cache_size = size;
    MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from "
                 << input_index.first->fullname_with_scope() << ", the cache size is " << size;
    return;
  }
}
void KernelRuntime::CheckSparsePSEmbeddingCache(const CNodePtr &node) {
  // Validate the indices chain of a SparseGatherV2 in PS cache mode: the
  // indices (input 1) must come from a Unique node, and that Unique must be
  // fed (possibly through Cast nodes only) directly by the dataset's GetNext.
  MS_EXCEPTION_IF_NULL(node);
  auto pre_node = AnfAlgo::GetPrevNodeOutput(node, 1, true);
  // Walk back from the indices input until a Unique node (or a non-CNode) is hit.
  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
    MS_EXCEPTION_IF_NULL(pre_node.first);
  }
  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
    MS_LOG(EXCEPTION) << "The input_indices of kernel[SparseGatherV2] must be unique in parameter server cache mode";
  }
  // Continue past the Unique node, skipping only Cast nodes.
  pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName)) {
    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
    MS_EXCEPTION_IF_NULL(pre_node.first);
  }
  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kGetNextOpName)) {
    MS_LOG(EXCEPTION) << "The input indices of kernel[Unique] must be produced from dataset directly and the indices "
                         "value can not be changed before delivering to kernel[Unique] in parameter server cache mode.";
  }
}
void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) {
  // Validate that PS embedding caching is used consistently across the graph:
  // every embedding lookup fed by the same dataset indices as the first cached
  // lookup must also enable caching, with the same cache size; lookups with
  // other index sources must not enable caching at all.
  MS_EXCEPTION_IF_NULL(graph);
  AnfNodePtr first_cache_input_index = nullptr;
  size_t first_cache_size = 0;
  GetFirstPSEmbeddingCache(graph, &first_cache_input_index, &first_cache_size);
  // Null here means no cached embedding lookup was found at all.
  MS_EXCEPTION_IF_NULL(first_cache_input_index);
  for (const auto &kernel : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
      continue;
    }
    // Input 0 is the embedding table parameter, input 1 is the lookup indices.
    auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
    auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
    MS_EXCEPTION_IF_NULL(input_param.first);
    MS_EXCEPTION_IF_NULL(input_index.first);
    if (!input_param.first->isa<Parameter>()) {
      continue;
    }
    auto param_name = input_param.first->fullname_with_scope();
    // Sparse lookups with caching have an extra constraint on their indices chain.
    if (ps::ps_cache_instance.IsHashTable(param_name) && (kernel_name == kSparseGatherV2OpName)) {
      CheckSparsePSEmbeddingCache(kernel);
    }
    // Walk back through any Cast nodes to reach the real indices producer.
    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
      MS_EXCEPTION_IF_NULL(input_index.first);
    }
    if (input_index.first == first_cache_input_index) {
      // Same index source as the first cached lookup: cache must be enabled
      // and sized identically.
      if (!ps::ps_cache_instance.IsHashTable(param_name)) {
        MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
        MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the "
                             "same time when one of them enables cache in parameter server training mode.";
      }
      auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
      if (size != first_cache_size) {
        MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope()
                      << ") is not the same as other embeddingLookup cache size(" << first_cache_size << ").";
        MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode.";
      }
    } else if (ps::ps_cache_instance.IsHashTable(param_name)) {
      // Cache enabled but indices come from somewhere other than the dataset.
      MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
                    << input_index.first->fullname_with_scope();
      MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
                           "parameter server training mode.";
    } else if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kGetNextOpName)) {
      // Indices come from the dataset but cache is not enabled.
      MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
      MS_LOG(EXCEPTION) << "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at "
                           "the same time and parameter 'sparse' must be equal to the value of 'enable_sparse' in "
                           "context setting in parameter server training mode.";
    }
  }
}
  1145. #endif
  1146. } // namespace device
  1147. } // namespace mindspore