You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

kernel_runtime.cc 30 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/kernel_runtime.h"
  17. #include <vector>
  18. #include <utility>
  19. #include <numeric>
  20. #include <functional>
  21. #include "common/utils.h"
  22. #include "common/trans.h"
  23. #include "utils/utils.h"
  24. #include "utils/context/ms_context.h"
  25. #include "operator/ops.h"
  26. #include "pipeline/parse/python_adapter.h"
  27. #include "session/kernel_graph.h"
  28. #include "session/anf_runtime_algorithm.h"
  29. #include "kernel/common_utils.h"
  30. #include "kernel/oplib/oplib.h"
  31. #include "ir/value.h"
  32. #include "pre_activate/common/helper.h"
  33. using mindspore::kernel::Address;
  34. using mindspore::kernel::AddressPtr;
  35. namespace mindspore {
  36. namespace device {
// Destructor: when E2E dump support is compiled in, drop the shared dump
// configuration so it is released together with the runtime.
KernelRuntime::~KernelRuntime() {
#ifdef ENABLE_DUMP_E2E
  dump_conf_ptr_ = nullptr;
#endif
}
// Execute the graph and log the wall-clock duration of the run.
// Dispatches to the task-sink path (whole-graph offload via RunTask) when the
// context enables it, otherwise launches kernels one by one (LaunchKernel).
// Returns true on success.
bool KernelRuntime::Run(session::KernelGraph *graph) {
  bool ret = false;
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
#if defined(_WIN32) || defined(_WIN64)
  // Windows: time with std::chrono (no gettimeofday available).
  auto start_time = std::chrono::steady_clock::now();
#else
  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);
#endif
  bool is_task_sink = context_ptr->enable_task_sink();
  if (is_task_sink) {
    ret = RunTask(graph);
  } else {
    ret = LaunchKernel(graph);
  }
#if defined(_WIN32) || defined(_WIN64)
  auto end_time = std::chrono::steady_clock::now();
  std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
  MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
#else
  (void)gettimeofday(&end_time, nullptr);
  // Convert the timeval pair to elapsed microseconds.
  const uint64_t kUSecondInSecond = 1000000;
  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
#endif
  return ret;
}
  71. // for D to impl
  72. bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
  73. if (graph != nullptr) {
  74. return true;
  75. }
  76. return false;
  77. }
  78. // for D to impl
  79. bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
  80. if (graph != nullptr) {
  81. return true;
  82. }
  83. return false;
  84. }
  85. // for D to impl
  86. bool KernelRuntime::GenTask(const session::KernelGraph *graph) {
  87. if (graph != nullptr) {
  88. return true;
  89. }
  90. return false;
  91. }
  92. bool KernelRuntime::LoadTask(const session::KernelGraph *graph) {
  93. if (graph != nullptr) {
  94. return true;
  95. }
  96. return false;
  97. }
  98. // for D to impl
  99. bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
  100. if (graph != nullptr) {
  101. return true;
  102. }
  103. return false;
  104. }
  105. bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
  106. MS_EXCEPTION_IF_NULL(kernel);
  107. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  108. return true;
  109. }
  110. return false;
  111. }
  112. size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
  113. MS_EXCEPTION_IF_NULL(node);
  114. if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
  115. MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
  116. << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
  117. }
  118. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  119. if (output_type_id == kTypeUnknown) {
  120. output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  121. }
  122. size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
  123. std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
  124. auto format = AnfAlgo::GetOutputFormat(node, output_index);
  125. if (shape.empty() && format != kOpFormat_DEFAULT) {
  126. shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
  127. shape = trans::TransShapeToDevice(shape, format);
  128. }
  129. // scalar's output shape is a empty vector
  130. size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
  131. return tensor_size;
  132. }
// Top-level memory assignment for a graph run: reset the dynamic-memory arena,
// then assign static memory (inputs, value nodes, outputs), then dynamic memory
// for intermediate results, and finally propagate addresses across ref-node pairs.
void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->ResetDynamicMemory();
  AssignStaticMemory(graph);
  AssignDynamicMemory(graph);
  UpdateRefNodeOutputMem(graph);
}
// Memory assignment for single-op (PyNative) execution: bind input tensors,
// allocate value-node memory, then per-kernel output and workspace memory,
// and finally propagate addresses across ref-node pairs.
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                      session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  RunOpAssignInputMemory(input_tensors, graph);
  AssignStaticMemoryValueNode(graph);
  for (const auto &cnode : graph->execution_order()) {
    RunOpAssignOutputMemory(cnode);
    RunOpAssignWorkSpaceMemory(cnode);
  }
  UpdateRefNodeOutputMem(graph);
}
  153. void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) {
  154. MS_EXCEPTION_IF_NULL(graph);
  155. // clear input parameter memory resource
  156. for (const auto &input_node : graph->inputs()) {
  157. MS_EXCEPTION_IF_NULL(input_node);
  158. AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
  159. }
  160. // clear input value node memory resource
  161. for (const auto &value_node : graph->graph_value_nodes()) {
  162. MS_EXCEPTION_IF_NULL(value_node);
  163. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  164. }
  165. for (const auto &cnode : graph->execution_order()) {
  166. MS_EXCEPTION_IF_NULL(cnode);
  167. // clear output memory resource
  168. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
  169. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  170. }
  171. // clear workspace memory resource
  172. auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  173. MS_EXCEPTION_IF_NULL(kernel_mod);
  174. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  175. for (size_t index = 0; index < workspace_lists.size(); ++index) {
  176. AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
  177. }
  178. }
  179. }
// Assign static (persistent across runs) memory: graph inputs first, then
// value nodes, then the graph's final outputs.
void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) {
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  AssignStaticMemoryOutput(graph);
}
  185. void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
  186. const session::KernelGraph *graph) {
  187. MS_EXCEPTION_IF_NULL(graph);
  188. MS_EXCEPTION_IF_NULL(mem_manager_);
  189. if (input_tensors.size() != graph->inputs().size()) {
  190. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
  191. << " should be equal to graph input parameter size " << graph->inputs().size();
  192. }
  193. for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) {
  194. auto item = graph->inputs()[input_index];
  195. MS_EXCEPTION_IF_NULL(item);
  196. if (!item->isa<Parameter>()) {
  197. continue;
  198. }
  199. auto output_size = AnfAlgo::GetOutputTensorNum(item);
  200. for (size_t index = 0; index < output_size; index++) {
  201. MS_EXCEPTION_IF_NULL(input_tensors[input_index]);
  202. if (input_tensors[input_index]->device_address().get() != nullptr) {
  203. AnfAlgo::SetOutputAddr(input_tensors[input_index]->device_address(), index, item.get());
  204. continue;
  205. }
  206. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  207. if (output_type_id == kTypeUnknown) {
  208. output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
  209. }
  210. auto tensor_size = CountNodeDeviceMemorySize(item, index);
  211. auto device_address =
  212. CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
  213. MS_EXCEPTION_IF_NULL(device_address);
  214. MS_EXCEPTION_IF_NULL(mem_manager_);
  215. auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
  216. if (!ret) {
  217. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  218. }
  219. AnfAlgo::SetOutputAddr(device_address, index, item.get());
  220. }
  221. }
  222. }
  223. void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
  224. MS_EXCEPTION_IF_NULL(kernel);
  225. MS_EXCEPTION_IF_NULL(mem_manager_);
  226. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  227. MS_EXCEPTION_IF_NULL(kernel_mod);
  228. auto output_sizes = kernel_mod->GetOutputSizeList();
  229. if (output_sizes.empty()) {
  230. return;
  231. }
  232. for (size_t i = 0; i < output_sizes.size(); ++i) {
  233. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  234. continue;
  235. }
  236. if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) {
  237. auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
  238. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  239. continue;
  240. }
  241. std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
  242. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  243. auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  244. device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
  245. MS_EXCEPTION_IF_NULL(device_address);
  246. auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
  247. if (!ret) {
  248. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  249. }
  250. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  251. }
  252. }
  253. void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  254. MS_EXCEPTION_IF_NULL(kernel);
  255. MS_EXCEPTION_IF_NULL(mem_manager_);
  256. if (kernel->isa<CNode>()) {
  257. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  258. MS_EXCEPTION_IF_NULL(kernel_mod);
  259. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  260. for (size_t i = 0; i < workspace_lists.size(); ++i) {
  261. auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
  262. MS_EXCEPTION_IF_NULL(device_address);
  263. auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
  264. if (!ret) {
  265. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  266. }
  267. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  268. }
  269. }
  270. }
// Allocate static device memory for every graph input parameter that does not
// yet hold a device address. MakeTuple inputs are flattened so each parameter
// inside the tuple is considered individually.
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto graph_inputs = graph->inputs();
  auto graph_valid_input = graph->valid_inputs();
  std::vector<AnfNodePtr> need_alloc_nodes;
  // Phase 1: collect the parameters that still need an address.
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    auto item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    // Skip inputs the front end marked as invalid.
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
      // Flatten a tuple input; each element is screened like a plain parameter.
      auto outs = AnfAlgo::GetAllOutput(item);
      for (auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        if (!out->isa<Parameter>()) {
          continue;
        }
        if (NodeOutputDeviceAddressExist(out, 0)) {
          continue;
        }
        need_alloc_nodes.push_back(out);
      }
    }
    if (!item->isa<Parameter>()) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(item, 0)) {
      continue;
    }
    need_alloc_nodes.push_back(item);
  }
  // Phase 2: allocate static memory and attach an address for every output.
  for (auto &item : need_alloc_nodes) {
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = CountNodeDeviceMemorySize(item, index);
      auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
      auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
      AnfAlgo::SetOutputAddr(address, index, item.get());
    }
  }
}
// Allocate static device memory for the graph's final output kernels.
// Communication ops are handled first (AssignCommunicationNodeMem lays out
// their buffers contiguously); the remaining outputs are assigned afterwards.
void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  std::vector<session::KernelWithIndex> non_communication_op;
  // Assign Communicate Op Memory firstly.
  for (const auto &node : nodes) {
    auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    // Only real kernels (CNodes) get output memory here; parameters and value
    // nodes are handled by the input/value-node assignment passes.
    if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
      continue;
    }
    graph->AddFinalOutputKernel(item_with_index.first);
    if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
      AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
    } else {
      non_communication_op.emplace_back(item_with_index);
    }
  }
  for (const auto &item_with_index : non_communication_op) {
    AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  }
}
// For every kernel output registered in the graph's ref-output map, make the
// output share the device address of its origin node (ref semantics: the two
// outputs alias the same storage). Only rewrites the address when it differs.
void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto &kernels = graph->execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto output_sizes = kernel_mod->GetOutputSizeList();
    if (output_sizes.empty()) {
      MS_LOG(INFO) << "This kernel has no output size.";
      continue;
    }
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      session::AnfWithOutIndex out_pair(kernel, i);
      if (graph->IsInRefOutputMap(out_pair)) {
        // This output refs another node's output: propagate that address.
        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
        MS_EXCEPTION_IF_NULL(origin_node_output_addr);
        auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
        if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
          MS_LOG(INFO) << "REF address is not same, ref node output need address update";
          MS_LOG(INFO) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                       << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}
// Assign memory for a communication op: inputs first (gathered contiguously),
// then outputs, using the given allocation flag (static/dynamic/reuse).
void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
  AssignCommunicationNodeInputMem(node);
  AssignCommunicationNodeOutputMem(flag, node);
}
// Allocate one contiguous buffer for all outputs of a communication op and
// carve per-output device addresses out of it. Sizes are aligned when HCCL is
// enabled so the collective library gets properly aligned segments.
void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  size_t output_index = 0;
  std::vector<size_t> align_size_list;
  for (uint64_t mem_size : output_sizes) {
    // Outputs that already own an address are excluded from the packed buffer.
    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
      MS_LOG(INFO) << "communication op addr exist";
      continue;
    }
    if (context_ptr->enable_hccl()) {
      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    align_size_list.emplace_back(mem_size);
  }
  uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
  // NOTE(review): when some outputs already had addresses, align_size_list is
  // shorter than output_sizes and index j below no longer corresponds to the
  // output it was computed from — verify whether mixed pre-assigned outputs can
  // actually occur here.
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type);
    AnfAlgo::SetOutputAddr(address, j, node.get());
    output_ptr += align_size_list[j];
  }
}
  411. DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
  412. MS_EXCEPTION_IF_NULL(anf_node);
  413. auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  414. auto output_sizes = kernel_mod->GetOutputSizeList();
  415. if (output_sizes.size() <= index) {
  416. MS_LOG(EXCEPTION) << "Previous node output size < node index";
  417. }
  418. std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  419. auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  420. auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
  421. AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  422. return address;
  423. }
// Allocate one contiguous buffer covering all inputs of a communication op.
// Each input (which must be a CNode output) first gets an empty device address
// via PreAssignCNodeMemory; the addresses are then pointed into consecutive
// aligned slices of the single allocation.
void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  // Pairs of (address to patch, aligned slice size).
  std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
    auto input_node = input_node_with_index.first;
    DeviceAddressPtr address = nullptr;
    if (input_node->isa<CNode>()) {
      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
    } else {
      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
    }
    MS_EXCEPTION_IF_NULL(address);
    auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
    total_size += mem_size;
    addr_size.emplace_back(address.get(), mem_size);
  }
  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size);
  // Patch every pre-assigned address to point into the shared buffer.
  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}
  452. void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
  453. MS_EXCEPTION_IF_NULL(node);
  454. MS_EXCEPTION_IF_NULL(mem_manager_);
  455. if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
  456. MS_LOG(INFO) << "GetNext disable mem_reuse";
  457. flag = kDynamicMem;
  458. }
  459. auto kernel_mod = AnfAlgo::GetKernelMod(node);
  460. MS_EXCEPTION_IF_NULL(kernel_mod);
  461. auto output_sizes = kernel_mod->GetOutputSizeList();
  462. if (output_sizes.empty()) {
  463. MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
  464. return;
  465. }
  466. for (size_t i = 0; i < output_sizes.size(); ++i) {
  467. if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
  468. continue;
  469. }
  470. if (NodeOutputDeviceAddressExist(node, i)) {
  471. MS_LOG(INFO) << "Already malloc index:" << i;
  472. continue;
  473. }
  474. auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]);
  475. if (ptr == nullptr) {
  476. // reused ptr, no need alloc, continue;
  477. continue;
  478. }
  479. std::string output_format = AnfAlgo::GetOutputFormat(node, i);
  480. auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
  481. auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type);
  482. device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
  483. AnfAlgo::SetOutputAddr(device_address, i, node.get());
  484. }
  485. }
// Allocate device memory for a tensor-valued ValueNode and copy the host data
// to the device. PyNative-infer mode draws from the memory pool; graph mode
// uses static memory. Throws NotExistsError when the host-to-device copy fails.
void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
                                          size_t output_idx) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto tensor = node_value->cast<TensorPtr>();
  if (tensor == nullptr) {
    MS_LOG(WARNING) << "Tensor is null";
    return;
  }
  // tensor_size is the host byte count; node_size is the (possibly padded)
  // device byte count used for the allocation.
  size_t tensor_size = tensor->data().nbytes();
  auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  if (output_type_id == kTypeUnknown) {
    output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  }
  auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
  DeviceAddressPtr address = nullptr;
  if (ms_context->enable_pynative_infer()) {
    // PyNative: allocate from the device memory pool.
    address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
    MS_EXCEPTION_IF_NULL(address);
    if (!mem_manager_->MallocMemFromMemPool(address, node_size)) {
      MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
    }
  } else {
    // Graph mode: value nodes live in static memory for the graph's lifetime.
    auto ptr = mem_manager_->MallocMem(kStaticMem, node_size);
    address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id);
    MS_EXCEPTION_IF_NULL(address);
  }
  AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
  // NOTE(review): the padding shape is queried for output 0 rather than
  // output_idx — confirm whether value nodes here are always single-output.
  if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
                                 tensor->data_c())) {
    MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
                                 << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
                                 << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  }
}
// Allocate device memory for every value node in the graph that lacks an
// address. Tensor values are delegated to AssignValueNodeTensor; string values
// are copied to the device as a [1, length] uint8 buffer.
void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeOutputDeviceAddressExist(value_node, 0)) {
      MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist";
      continue;
    }
    auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (node_value->isa<Tensor>()) {
      AssignValueNodeTensor(value_node, node_value, 0);
    } else if (node_value->isa<StringImm>()) {
      // String constant: upload the raw bytes as uint8 data.
      auto value = GetValue<std::string>(node_value);
      size_t tensor_size = value.size();
      DeviceAddressPtr address = nullptr;
      if (ms_context->enable_pynative_infer()) {
        // PyNative: draw from the memory pool.
        address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
        MS_EXCEPTION_IF_NULL(address);
        if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
          MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
        }
      } else {
        // Graph mode: static memory for the graph's lifetime.
        auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
        address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
        MS_EXCEPTION_IF_NULL(address);
      }
      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
      std::vector<int> shape = {1, SizeToInt(tensor_size)};
      if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
        MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!";
      }
    }
  }
}
// Allocate dynamic (per-run) memory for intermediate results. When memory
// reuse is enabled, the reuse planner pre-allocates a shared arena and kernels
// are tagged kReuseDynamicMem. Communication nodes are assigned before compute
// nodes because their inputs/outputs must be packed contiguously.
void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
  auto mem_flag = kDynamicMem;
  if (is_enable_mem_reuse) {
    mem_manager_->MallocReusedDynamicMem(graph);
    mem_flag = kReuseDynamicMem;
  }
  auto &execution_nodes = graph->execution_order();
  std::vector<CNodePtr> compute_nodes;
  // communication nodes first
  for (auto &node : execution_nodes) {
    if (AnfAlgo::IsCommunicationOp(node)) {
      // skip if the memory is already alocated
      AssignCommunicationNodeMem(mem_flag, node);
    } else {
      compute_nodes.emplace_back(node);
    }
  }
  // then compute nodes
  for (auto &node : compute_nodes) {
    AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
    AssignWorkSpaceMem(mem_flag, node);
  }
}
  591. void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
  592. MS_EXCEPTION_IF_NULL(node);
  593. MS_EXCEPTION_IF_NULL(mem_manager_);
  594. auto kernel_mod = AnfAlgo::GetKernelMod(node);
  595. MS_EXCEPTION_IF_NULL(kernel_mod);
  596. size_t index = 0;
  597. for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
  598. auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size);
  599. AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
  600. index++;
  601. }
  602. }
// Collect the input/workspace/output address lists needed to launch `kernel`.
// AtomicAddrClean kernels are delegated to GenAddrCleanLaunchArgs. When the
// graph contains nop nodes, real (non-visited) addresses are fetched so nop
// nodes are skipped transparently.
void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel,
                                  AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
                                  AddressPtrList *kernel_outputs) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  MS_EXCEPTION_IF_NULL(kernel_workspaces);
  MS_EXCEPTION_IF_NULL(kernel_outputs);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
    return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
  }
  auto is_all_nop_node = opt::IsAllNopNode(&graph);
  // Inputs: one Address per real input, resolved through nop nodes as needed.
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
    auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
    DeviceAddressPtr device_address;
    if (is_all_nop_node) {
      device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false);
    } else {
      device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true);
    }
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr input = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(input);
    input->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(input->addr);
    input->size = device_address->size_;
    kernel_inputs->emplace_back(input);
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  // Outputs: one Address per declared output size.
  for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
    DeviceAddressPtr device_address;
    if (is_all_nop_node) {
      device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
    } else {
      device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true);
    }
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr output = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(output);
    output->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(output->addr);
    output->size = device_address->size_;
    kernel_outputs->emplace_back(output);
  }
  // Workspaces: one Address per declared workspace size.
  for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(workspace);
    workspace->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(workspace->addr);
    workspace->size = device_address->size_;
    kernel_workspaces->emplace_back(workspace);
  }
}
  659. void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) {
  660. if (cnode->inputs().size() != 2) {
  661. MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
  662. }
  663. MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
  664. auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  665. // set clean output address
  666. if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
  667. auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
  668. for (auto index : clean_output_indexs) {
  669. auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
  670. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  671. MS_EXCEPTION_IF_NULL(input);
  672. input->addr = device_address->ptr_;
  673. MS_EXCEPTION_IF_NULL(input->addr);
  674. input->size = device_address->size_;
  675. kernel_inputs->emplace_back(input);
  676. }
  677. MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
  678. }
  679. // set clean workspace address
  680. if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
  681. auto clean_workspaces_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
  682. for (const auto &index : clean_workspaces_indexs) {
  683. auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
  684. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  685. MS_EXCEPTION_IF_NULL(workspace);
  686. workspace->addr = device_address->ptr_;
  687. MS_EXCEPTION_IF_NULL(workspace->addr);
  688. workspace->size = device_address->size_;
  689. kernel_inputs->emplace_back(workspace);
  690. }
  691. }
  692. }
  693. bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
  694. auto &kernels = graph.execution_order();
  695. for (const auto &kernel : kernels) {
  696. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  697. MS_EXCEPTION_IF_NULL(kernel_mod);
  698. AddressPtrList kernel_inputs;
  699. AddressPtrList kernel_workspaces;
  700. AddressPtrList kernel_outputs;
  701. GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  702. auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  703. if (!ret) {
  704. MS_LOG(ERROR) << "Launch kernel failed.";
  705. return false;
  706. }
  707. }
  708. return true;
  709. }
  710. bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
  711. MS_EXCEPTION_IF_NULL(graph);
  712. if (!LaunchKernelMod(*graph)) {
  713. MS_LOG(ERROR) << "LaunchKernelMod failed!";
  714. return false;
  715. }
  716. return true;
  717. }
// Base-class hook for releasing per-graph runtime resources; derived runtimes
// override this. Here it only logs the request.
void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
  MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}
#ifdef ENABLE_DUMP_E2E
// Create the dump configuration object and load its settings from the JSON
// config file. Returns true when the file was parsed successfully.
bool KernelRuntime::SetDumpConf() {
  dump_conf_ptr_ = std::make_shared<Dump>();
  MS_EXCEPTION_IF_NULL(dump_conf_ptr_);
  bool ret = dump_conf_ptr_->SetDumpConfFromJsonFile();
  return ret;
}
// Accessor for the dump configuration created by SetDumpConf().
DumpConfPtr KernelRuntime::GetDumpConf() { return dump_conf_ptr_; }
#endif
  730. } // namespace device
  731. } // namespace mindspore