You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

kernel_runtime.cc 28 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/kernel_runtime.h"
  17. #include <vector>
  18. #include <utility>
  19. #include <numeric>
  20. #include <functional>
  21. #include "common/utils.h"
  22. #include "common/trans.h"
  23. #include "utils/utils.h"
  24. #include "utils/context/ms_context.h"
  25. #include "operator/ops.h"
  26. #include "pipeline/parse/python_adapter.h"
  27. #include "session/kernel_graph.h"
  28. #include "session/anf_runtime_algorithm.h"
  29. #include "kernel/common_utils.h"
  30. #include "kernel/oplib/oplib.h"
  31. #include "ir/value.h"
  32. using mindspore::kernel::Address;
  33. using mindspore::kernel::AddressPtr;
  34. namespace mindspore {
  35. namespace device {
// Destructor: releases the end-to-end dump configuration when dump support is compiled in.
KernelRuntime::~KernelRuntime() {
#ifdef ENABLE_DUMP_E2E
  // Drop the shared dump configuration so it is freed together with the runtime.
  dump_conf_ptr_ = nullptr;
#endif
}
// Execute the given graph once and log the wall-clock cost in microseconds.
// Dispatches to whole-graph task execution when task sink is enabled in the
// MsContext, otherwise launches kernels one by one.
bool KernelRuntime::Run(session::KernelGraph *graph) {
  bool ret = false;
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
#if defined(_WIN32) || defined(_WIN64)
  // Windows has no gettimeofday(); use the monotonic steady_clock instead.
  auto start_time = std::chrono::steady_clock::now();
#else
  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);
#endif
  bool is_task_sink = context_ptr->enable_task_sink();
  if (is_task_sink) {
    ret = RunTask(graph);
  } else {
    ret = LaunchKernel(graph);
  }
#if defined(_WIN32) || defined(_WIN64)
  auto end_time = std::chrono::steady_clock::now();
  std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
  MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
#else
  (void)gettimeofday(&end_time, nullptr);
  const uint64_t kUSecondInSecond = 1000000;
  // Seconds and microseconds parts are combined into a single microsecond count;
  // any unsigned wrap of the usec difference cancels out in the final modular sum.
  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
#endif
  return ret;
}
  70. // for D to impl
  71. bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
  72. if (graph != nullptr) {
  73. return true;
  74. }
  75. return false;
  76. }
  77. // for D to impl
  78. bool KernelRuntime::GenTask(const session::KernelGraph *graph) {
  79. if (graph != nullptr) {
  80. return true;
  81. }
  82. return false;
  83. }
  84. bool KernelRuntime::LoadTask(const session::KernelGraph *graph) {
  85. if (graph != nullptr) {
  86. return true;
  87. }
  88. return false;
  89. }
  90. // for D to impl
  91. bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
  92. if (graph != nullptr) {
  93. return true;
  94. }
  95. return false;
  96. }
  97. bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
  98. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  99. return true;
  100. }
  101. return false;
  102. }
// Compute the device memory size in bytes for output `output_index` of `node`:
// the element-type size multiplied by every dimension of the device shape.
// Throws ArgumentError when output_index is out of range.
size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
    MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
                                << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
  }
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (output_type_id == kTypeUnknown) {
    // No device data type selected yet; fall back to the inferred type.
    output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  }
  size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
  std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
  auto format = AnfAlgo::GetOutputFormat(node, output_index);
  // NOTE(review): the 4d padding/device transform runs only when the device shape
  // is empty (scalar) AND the format is non-default; a non-empty shape is used
  // as-is. Confirm this is intended rather than `!shape.empty()`.
  if (shape.empty() && format != kOpFormat_DEFAULT) {
    shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
    shape = trans::TransShapeToDevice(shape, format);
  }
  // scalar's output shape is a empty vector
  size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
  return tensor_size;
}
  124. void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
  125. auto context_ptr = MsContext::GetInstance();
  126. MS_EXCEPTION_IF_NULL(context_ptr);
  127. MS_EXCEPTION_IF_NULL(mem_manager_);
  128. mem_manager_->ResetDynamicMemory();
  129. AssignStaticMemory(graph);
  130. AssignDynamicMemory(graph);
  131. UpdateRefNodeOutputMem(graph);
  132. }
  133. void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
  134. session::KernelGraph *graph) {
  135. MS_EXCEPTION_IF_NULL(graph);
  136. RunOpAssignInputMemory(input_tensors, graph);
  137. AssignStaticMemoryValueNode(graph);
  138. for (const auto &cnode : graph->execution_order()) {
  139. RunOpAssignOutputMemory(cnode);
  140. RunOpAssignWorkSpaceMemory(cnode);
  141. }
  142. UpdateRefNodeOutputMem(graph);
  143. }
// Assign static (non-reused) device memory in a fixed order: graph inputs
// first, then value nodes, then graph outputs.
void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) {
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  AssignStaticMemoryOutput(graph);
}
// Allocate device memory for every Parameter input of the graph in single-op
// mode. Tensors that already carry a device address simply have that address
// bound to the parameter; all others get fresh memory from the memory pool.
// Throws when the tensor count does not match the graph's input count or when
// pool allocation fails.
void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                           const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (input_tensors.size() != graph->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
                      << " should be equal to graph input parameter size " << graph->inputs().size();
  }
  for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) {
    auto item = graph->inputs()[input_index];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      MS_EXCEPTION_IF_NULL(input_tensors[input_index]);
      if (input_tensors[input_index]->device_address().get() != nullptr) {
        // The tensor already lives on device: reuse its address instead of reallocating.
        AnfAlgo::SetOutputAddr(input_tensors[input_index]->device_address(), index, item.get());
        continue;
      }
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        // No device data type selected yet; fall back to the inferred type.
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = CountNodeDeviceMemorySize(item, index);
      auto device_address =
        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
      MS_EXCEPTION_IF_NULL(device_address);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Malloc device memory failed.";
      }
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}
// Allocate pool memory for every output of `kernel` in single-op mode.
// Outputs that already have an address are skipped; ApplyMomentum binds the
// previous node's output address as its own output (in-place update).
void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if (AnfAlgo::OutputAddrExist(kernel, i)) {
      continue;
    }
    if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) {
      // ApplyMomentum writes in place: share the input's device address as the output.
      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
      continue;
    }
    std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
    MS_EXCEPTION_IF_NULL(device_address);
    auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Malloc device memory failed.";
    }
    AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  }
}
  215. void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  216. MS_EXCEPTION_IF_NULL(kernel);
  217. MS_EXCEPTION_IF_NULL(mem_manager_);
  218. if (kernel->isa<CNode>()) {
  219. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  220. MS_EXCEPTION_IF_NULL(kernel_mod);
  221. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  222. for (size_t i = 0; i < workspace_lists.size(); ++i) {
  223. auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
  224. MS_EXCEPTION_IF_NULL(device_address);
  225. auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
  226. if (!ret) {
  227. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  228. }
  229. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  230. }
  231. }
  232. }
// Reserve static device memory for every valid Parameter input of the graph
// that does not yet own a device address.
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto graph_inputs = graph->inputs();
  auto graph_valid_input = graph->valid_inputs();
  for (size_t i = 0; i < graph_inputs.size(); i++) {
    auto item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    // Skip inputs the graph flags as invalid (when the flag list covers this index).
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    // Only output 0 is checked for an existing address before assigning all outputs.
    if (NodeOutputDeviceAddressExist(item, 0)) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = CountNodeDeviceMemorySize(item, index);
      auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
      auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
      AnfAlgo::SetOutputAddr(address, index, item.get());
    }
  }
}
  265. void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
  266. MS_EXCEPTION_IF_NULL(graph);
  267. auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  268. std::vector<session::KernelWithIndex> non_communication_op;
  269. // Assign Communicate Op Memory firstly.
  270. for (const auto &node : nodes) {
  271. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
  272. MS_EXCEPTION_IF_NULL(item_with_index.first);
  273. if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
  274. continue;
  275. }
  276. if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
  277. AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
  278. } else {
  279. non_communication_op.emplace_back(item_with_index);
  280. }
  281. }
  282. for (const auto &item_with_index : non_communication_op) {
  283. AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  284. }
  285. }
// For every kernel output registered in the graph's ref-output map, make the
// output share the device address of its corresponding origin node, so that
// ref semantics (writing through to the origin buffer) hold.
void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto &kernels = graph->execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto output_sizes = kernel_mod->GetOutputSizeList();
    if (output_sizes.empty()) {
      MS_LOG(INFO) << "This kernel has no output size.";
      continue;
    }
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      session::AnfWithOutIndex out_pair(kernel, i);
      if (graph->IsInRefOutputMap(out_pair)) {
        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
        MS_EXCEPTION_IF_NULL(origin_node_output_addr);
        auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
        // Rebind only when the addresses actually differ; identical addresses need no update.
        if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
          MS_LOG(INFO) << "REF address is not same, ref node output need address update";
          MS_LOG(INFO) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                       << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}
// Assign memory for a communication op: its contiguous input area first, then
// its (also contiguous) output area with the given memory flag.
void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
  AssignCommunicationNodeInputMem(node);
  AssignCommunicationNodeOutputMem(flag, node);
}
// Allocate one contiguous device block covering all outputs of a communication
// op, then slice it into per-output device addresses. Sizes are aligned when
// HCCL is enabled so each slice starts on an aligned boundary.
void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  size_t output_index = 0;
  std::vector<size_t> align_size_list;
  // NOTE(review): outputs that already have an address are skipped when building
  // align_size_list, but the slicing loop below indexes output_sizes/formats by
  // the align-list index j — those indices diverge if an address pre-existed
  // mid-list. Confirm whether partial pre-assignment can occur here.
  for (uint64_t mem_size : output_sizes) {
    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
      MS_LOG(INFO) << "communication op addr exist";
      continue;
    }
    if (context_ptr->enable_hccl()) {
      // Round each slice up to the common alignment required by HCCL.
      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    align_size_list.emplace_back(mem_size);
  }
  uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type);
    AnfAlgo::SetOutputAddr(address, j, node.get());
    // Advance by the aligned size so the next output starts on its own slice.
    output_ptr += align_size_list[j];
  }
}
  355. DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
  356. MS_EXCEPTION_IF_NULL(anf_node);
  357. auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  358. auto output_sizes = kernel_mod->GetOutputSizeList();
  359. if (output_sizes.size() <= index) {
  360. MS_LOG(EXCEPTION) << "Previous node output size < node index";
  361. }
  362. std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  363. auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  364. auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
  365. AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  366. return address;
  367. }
// Allocate one contiguous device block covering every input of a communication
// op. Each input's device address is first pre-created on its producing CNode,
// then pointed into the block at its aligned offset.
void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  // Pairs of (raw device-address pointer, aligned size) in input order.
  std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
    auto input_node = input_node_with_index.first;
    DeviceAddressPtr address = nullptr;
    if (input_node->isa<CNode>()) {
      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
    } else {
      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
    }
    MS_EXCEPTION_IF_NULL(address);
    auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
    total_size += mem_size;
    // Raw pointer is safe here: the shared_ptr was registered on the node above.
    addr_size.emplace_back(address.get(), mem_size);
  }
  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size);
  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}
// Assign device memory for the outputs of `node`. `index` selects a single
// output, or kGetAllOuts for every output. A null pointer returned from
// MallocOutputMem means the buffer is reused from elsewhere, so no new
// device address is created for that output.
void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  // GetNext must not participate in memory reuse; fall back to plain dynamic memory.
  if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
    MS_LOG(INFO) << "GetNext disable mem_reuse";
    flag = kDynamicMem;
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    // When a specific index is requested, skip every other output.
    if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(node, i)) {
      MS_LOG(INFO) << "Already malloc index:" << i;
      continue;
    }
    auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]);
    if (ptr == nullptr) {
      // reused ptr, no need alloc, continue;
      continue;
    }
    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
    AnfAlgo::SetOutputAddr(CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type), i, node.get());
  }
}
// Allocate device memory for a tensor-valued value node and copy the host
// tensor data to the device. Pool memory is used in PyNative-infer mode,
// static memory otherwise.
void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
                                          size_t output_idx) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto tensor = node_value->cast<TensorPtr>();
  if (tensor == nullptr) {
    // Non-tensor values are silently ignored (best effort, matches caller's dispatch).
    MS_LOG(WARNING) << "Tensor is null";
    return;
  }
  // Host byte count (for the copy) vs. device byte count (for the allocation).
  size_t tensor_size = tensor->data().nbytes();
  auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  if (output_type_id == kTypeUnknown) {
    output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  }
  auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
  DeviceAddressPtr address = nullptr;
  if (ms_context->enable_pynative_infer()) {
    // PyNative mode: take memory from the dynamic pool.
    address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
    MS_EXCEPTION_IF_NULL(address);
    if (!mem_manager_->MallocMemFromMemPool(address, node_size)) {
      MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
    }
  } else {
    // Graph mode: value nodes live in the static memory region.
    auto ptr = mem_manager_->MallocMem(kStaticMem, node_size);
    address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id);
    MS_EXCEPTION_IF_NULL(address);
  }
  AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
  // NOTE(review): the padding shape is taken from output 0 even though
  // output_idx may differ — confirm value nodes here always have one output.
  if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
                                 tensor->data_c(false))) {
    MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
                                 << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
                                 << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  }
}
// Assign device memory for every value node in the graph that does not yet
// have an address. Tensor values are delegated to AssignValueNodeTensor;
// string values are copied to the device as raw uint8 bytes with shape
// {1, length}. Other value kinds are skipped.
void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeOutputDeviceAddressExist(value_node, 0)) {
      MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist";
      continue;
    }
    auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (node_value->isa<Tensor>()) {
      AssignValueNodeTensor(value_node, node_value, 0);
    } else if (node_value->isa<StringImm>()) {
      auto value = GetValue<std::string>(node_value);
      size_t tensor_size = value.size();
      DeviceAddressPtr address = nullptr;
      if (ms_context->enable_pynative_infer()) {
        // PyNative mode: strings also come from the dynamic pool.
        address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
        MS_EXCEPTION_IF_NULL(address);
        if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
          MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
        }
      } else {
        auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
        address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
        MS_EXCEPTION_IF_NULL(address);
      }
      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
      std::vector<int> shape = {1, SizeToInt(tensor_size)};
      if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
        MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!";
      }
    }
  }
}
  505. void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
  506. MS_EXCEPTION_IF_NULL(graph);
  507. MS_EXCEPTION_IF_NULL(mem_manager_);
  508. auto context_ptr = MsContext::GetInstance();
  509. MS_EXCEPTION_IF_NULL(context_ptr);
  510. bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
  511. auto mem_flag = kDynamicMem;
  512. if (is_enable_mem_reuse) {
  513. mem_manager_->MallocReusedDynamicMem(graph);
  514. mem_flag = kReuseDynamicMem;
  515. }
  516. auto &execution_nodes = graph->execution_order();
  517. std::vector<CNodePtr> compute_nodes;
  518. // communication nodes first
  519. for (auto &node : execution_nodes) {
  520. if (AnfAlgo::IsCommunicationOp(node)) {
  521. // skip if the memory is already alocated
  522. AssignCommunicationNodeMem(mem_flag, node);
  523. } else {
  524. compute_nodes.emplace_back(node);
  525. }
  526. }
  527. // then compute nodes
  528. for (auto &node : compute_nodes) {
  529. AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
  530. AssignWorkSpaceMem(mem_flag, node);
  531. }
  532. }
  533. void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
  534. MS_EXCEPTION_IF_NULL(node);
  535. MS_EXCEPTION_IF_NULL(mem_manager_);
  536. auto kernel_mod = AnfAlgo::GetKernelMod(node);
  537. MS_EXCEPTION_IF_NULL(kernel_mod);
  538. size_t index = 0;
  539. for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
  540. auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size);
  541. AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
  542. index++;
  543. }
  544. }
  545. void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
  546. AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
  547. AddressPtrList *kernel_outputs) {
  548. MS_EXCEPTION_IF_NULL(kernel);
  549. MS_EXCEPTION_IF_NULL(kernel_inputs);
  550. MS_EXCEPTION_IF_NULL(kernel_workspaces);
  551. MS_EXCEPTION_IF_NULL(kernel_outputs);
  552. auto cnode = kernel->cast<CNodePtr>();
  553. MS_EXCEPTION_IF_NULL(cnode);
  554. if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
  555. return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
  556. }
  557. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
  558. auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
  559. auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
  560. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  561. MS_EXCEPTION_IF_NULL(input);
  562. input->addr = device_address->ptr_;
  563. MS_EXCEPTION_IF_NULL(input->addr);
  564. input->size = device_address->size_;
  565. kernel_inputs->emplace_back(input);
  566. }
  567. for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
  568. auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
  569. kernel::AddressPtr output = std::make_shared<kernel::Address>();
  570. MS_EXCEPTION_IF_NULL(output);
  571. output->addr = device_address->ptr_;
  572. MS_EXCEPTION_IF_NULL(output->addr);
  573. output->size = device_address->size_;
  574. kernel_outputs->emplace_back(output);
  575. }
  576. for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
  577. auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
  578. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  579. MS_EXCEPTION_IF_NULL(workspace);
  580. workspace->addr = device_address->ptr_;
  581. MS_EXCEPTION_IF_NULL(workspace->addr);
  582. workspace->size = device_address->size_;
  583. kernel_workspaces->emplace_back(workspace);
  584. }
  585. }
  586. void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) {
  587. if (cnode->inputs().size() != 2) {
  588. MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
  589. }
  590. MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
  591. auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  592. // set clean output address
  593. if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) {
  594. auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAutomicOutputIndexs);
  595. for (auto index : clean_output_indexs) {
  596. auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
  597. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  598. MS_EXCEPTION_IF_NULL(input);
  599. input->addr = device_address->ptr_;
  600. MS_EXCEPTION_IF_NULL(input->addr);
  601. input->size = device_address->size_;
  602. kernel_inputs->emplace_back(input);
  603. }
  604. MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
  605. }
  606. // set clean workspace address
  607. if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) {
  608. auto clean_workspaces = AnfAlgo::GetNodeAttr<int>(pre_node, kAttrAutomicWorkspaceSize);
  609. if (clean_workspaces != 0) {
  610. auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, 0);
  611. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  612. MS_EXCEPTION_IF_NULL(workspace);
  613. workspace->addr = device_address->ptr_;
  614. MS_EXCEPTION_IF_NULL(workspace->addr);
  615. workspace->size = device_address->size_;
  616. kernel_inputs->emplace_back(workspace);
  617. }
  618. MS_LOG(INFO) << "AtomicAddClean clean workspace size" << clean_workspaces;
  619. }
  620. }
  621. bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
  622. auto &kernels = graph.execution_order();
  623. for (const auto &kernel : kernels) {
  624. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  625. MS_EXCEPTION_IF_NULL(kernel_mod);
  626. AddressPtrList kernel_inputs;
  627. AddressPtrList kernel_workspaces;
  628. AddressPtrList kernel_outputs;
  629. GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  630. auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  631. if (!ret) {
  632. MS_LOG(ERROR) << "Launch kernel failed.";
  633. return false;
  634. }
  635. }
  636. return true;
  637. }
  638. bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
  639. MS_EXCEPTION_IF_NULL(graph);
  640. if (!LaunchKernelMod(*graph)) {
  641. MS_LOG(ERROR) << "LaunchKernelMod failed!";
  642. return false;
  643. }
  644. return true;
  645. }
// Release runtime resources held for the given graph id. The base class only
// logs; device runtimes override this to free real resources.
void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
  MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}
  649. #ifdef ENABLE_DUMP_E2E
  650. bool KernelRuntime::SetDumpConf() {
  651. dump_conf_ptr_ = std::make_shared<Dump>();
  652. MS_EXCEPTION_IF_NULL(dump_conf_ptr_);
  653. bool ret = dump_conf_ptr_->SetDumpConfFromJsonFile();
  654. return ret;
  655. }
// Accessor for the dump configuration created by SetDumpConf (null until set).
DumpConfPtr KernelRuntime::GetDumpConf() { return dump_conf_ptr_; }
  657. #endif
  658. } // namespace device
  659. } // namespace mindspore