You cannot select more than 25 topics. A topic must start with a Chinese character, a letter or a number, may include dashes ('-'), and can be up to 35 characters long.

kernel_runtime.cc 26 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/kernel_runtime.h"
  17. #include <utility>
  18. #include <numeric>
  19. #include <functional>
  20. #include "common/utils.h"
  21. #include "common/trans.h"
  22. #include "utils/utils.h"
  23. #include "utils/context/ms_context.h"
  24. #include "operator/ops.h"
  25. #include "pipeline/parse/python_adapter.h"
  26. #include "session/kernel_graph.h"
  27. #include "session/anf_runtime_algorithm.h"
  28. #include "kernel/common_utils.h"
  29. #include "kernel/oplib/oplib.h"
  30. #include "ir/value.h"
  31. using mindspore::kernel::Address;
  32. using mindspore::kernel::AddressPtr;
  33. namespace mindspore {
  34. namespace device {
// Destructor: drops the e2e dump configuration when built with dump support.
// All other members are released by their own destructors.
KernelRuntime::~KernelRuntime() {
#ifdef ENABLE_DUMP_E2E
  dump_conf_ptr_ = nullptr;
#endif
}
  40. bool KernelRuntime::Run(session::KernelGraph *graph) {
  41. bool ret = false;
  42. auto context_ptr = MsContext::GetInstance();
  43. MS_EXCEPTION_IF_NULL(context_ptr);
  44. #if defined(_WIN32) || defined(_WIN64)
  45. auto start_time = std::chrono::steady_clock::now();
  46. #else
  47. struct timeval start_time, end_time;
  48. (void)gettimeofday(&start_time, nullptr);
  49. #endif
  50. bool is_task_sink = context_ptr->enable_task_sink();
  51. if (is_task_sink) {
  52. ret = RunTask(graph);
  53. } else {
  54. ret = LaunchKernel(graph);
  55. }
  56. #if defined(_WIN32) || defined(_WIN64)
  57. auto end_time = std::chrono::steady_clock::now();
  58. std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
  59. MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
  60. #else
  61. (void)gettimeofday(&end_time, nullptr);
  62. const uint64_t kUSecondInSecond = 1000000;
  63. uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  64. cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  65. MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
  66. #endif
  67. return ret;
  68. }
  69. // for D to impl
  70. bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
  71. if (graph != nullptr) {
  72. return true;
  73. }
  74. return false;
  75. }
  76. // for D to impl
  77. bool KernelRuntime::GenTask(const session::KernelGraph *graph) {
  78. if (graph != nullptr) {
  79. return true;
  80. }
  81. return false;
  82. }
  83. bool KernelRuntime::LoadTask(const session::KernelGraph *graph) {
  84. if (graph != nullptr) {
  85. return true;
  86. }
  87. return false;
  88. }
  89. // for D to impl
  90. bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
  91. if (graph != nullptr) {
  92. return true;
  93. }
  94. return false;
  95. }
  96. size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
  97. MS_EXCEPTION_IF_NULL(node);
  98. if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
  99. MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
  100. << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
  101. }
  102. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  103. if (output_type_id == kTypeUnknown) {
  104. output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  105. }
  106. size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
  107. std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
  108. auto format = AnfAlgo::GetOutputFormat(node, output_index);
  109. if (shape.empty() && format != kOpFormat_DEFAULT) {
  110. shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
  111. shape = trans::TransShapeToDevice(shape, format);
  112. }
  113. // scalar's output shape is a empty vector
  114. size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
  115. return tensor_size;
  116. }
// Assign device memory for the whole graph: static memory (inputs, value
// nodes, outputs) first, then dynamic memory for intermediate results, and
// finally rebind ref-node outputs to their origin addresses.
void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  // The dynamic region is re-planned from scratch on every assignment pass.
  mem_manager_->ResetDynamicMemory();
  AssignStaticMemory(graph);
  AssignDynamicMemory(graph);
  UpdateRefNodeOutputMem(graph);
}
  126. void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
  127. session::KernelGraph *graph) {
  128. MS_EXCEPTION_IF_NULL(graph);
  129. // assign memory for input nodes
  130. RunOpAssignInputMemory(input_tensors, graph);
  131. AssignStaticMemoryValueNode(graph);
  132. for (const auto &cnode : graph->execution_order()) {
  133. // assign memory for output nodes
  134. RunOpAssignOutputMemory(cnode);
  135. // assign memory for workspace
  136. RunOpAssignWorkSpaceMemory(cnode);
  137. }
  138. UpdateRefNodeOutputMem(graph);
  139. }
// Static memory covers graph inputs (parameters), value nodes (constants),
// and graph outputs; each category has its own assignment pass.
void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) {
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  AssignStaticMemoryOutput(graph);
}
  145. void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
  146. const session::KernelGraph *graph) {
  147. MS_EXCEPTION_IF_NULL(graph);
  148. MS_EXCEPTION_IF_NULL(mem_manager_);
  149. for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) {
  150. auto item = graph->inputs()[input_index];
  151. MS_EXCEPTION_IF_NULL(item);
  152. if (!item->isa<Parameter>()) {
  153. continue;
  154. }
  155. auto output_size = AnfAlgo::GetOutputTensorNum(item);
  156. for (size_t index = 0; index < output_size; index++) {
  157. MS_EXCEPTION_IF_NULL(input_tensors[input_index]);
  158. if (input_tensors[input_index]->device_address().get() != nullptr) {
  159. AnfAlgo::SetOutputAddr(input_tensors[input_index]->device_address(), index, item.get());
  160. continue;
  161. }
  162. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  163. if (output_type_id == kTypeUnknown) {
  164. output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
  165. }
  166. auto tensor_size = CountNodeDeviceMemorySize(item, index);
  167. auto device_address =
  168. CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
  169. MS_EXCEPTION_IF_NULL(device_address);
  170. auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
  171. if (!ret) {
  172. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  173. }
  174. AnfAlgo::SetOutputAddr(device_address, index, item.get());
  175. }
  176. }
  177. }
// Assign pooled device memory to each output of `kernel` in single-op mode.
// Outputs that already have an address are left untouched.
void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if (AnfAlgo::OutputAddrExist(kernel, i)) {
      continue;
    }
    // ApplyMomentum outputs alias the corresponding input addresses instead of
    // getting fresh memory (in-place parameter update).
    if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) {
      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
      continue;
    }
    std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
    MS_EXCEPTION_IF_NULL(device_address);
    auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Malloc device memory failed.";
    }
    AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  }
}
  207. void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  208. MS_EXCEPTION_IF_NULL(kernel);
  209. MS_EXCEPTION_IF_NULL(mem_manager_);
  210. if (kernel->isa<CNode>()) {
  211. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  212. MS_EXCEPTION_IF_NULL(kernel_mod);
  213. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  214. for (size_t i = 0; i < workspace_lists.size(); ++i) {
  215. auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
  216. MS_EXCEPTION_IF_NULL(device_address);
  217. auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
  218. if (!ret) {
  219. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  220. }
  221. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  222. }
  223. }
  224. }
// Assign static device memory to every valid Parameter input of the graph
// that does not already carry a device address.
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto graph_inputs = graph->inputs();
  auto graph_valid_input = graph->valid_inputs();
  for (size_t i = 0; i < graph_inputs.size(); i++) {
    auto item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    // Skip inputs the session flagged as invalid.
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    // An address on output 0 is treated as "this node is already assigned".
    if (AnfAlgo::OutputAddrExist(item, 0)) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = CountNodeDeviceMemorySize(item, index);
      auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
      auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
      AnfAlgo::SetOutputAddr(address, index, item.get());
    }
  }
}
  257. void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
  258. MS_EXCEPTION_IF_NULL(graph);
  259. auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  260. for (const auto &node : nodes) {
  261. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
  262. MS_EXCEPTION_IF_NULL(item_with_index.first);
  263. if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
  264. continue;
  265. }
  266. AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  267. }
  268. }
// For every kernel output registered in the graph's ref-output map, make the
// output share the device address of its origin node (the node it aliases),
// so ref/in-place semantics hold after memory assignment.
void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto &kernels = graph->execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto output_sizes = kernel_mod->GetOutputSizeList();
    if (output_sizes.empty()) {
      MS_LOG(INFO) << "This kernel has no output size.";
      continue;
    }
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      session::AnfWithOutIndex out_pair(kernel, i);
      if (graph->IsInRefOutputMap(out_pair)) {
        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
        MS_EXCEPTION_IF_NULL(origin_node_output_addr);
        auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
        // Rebind only when the two addresses actually differ.
        if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
          MS_LOG(INFO) << "REF address is not same, ref node output need address update";
          MS_LOG(INFO) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                       << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}
// Allocate one contiguous block for all outputs of a communication node and
// carve per-output addresses out of it. With HCCL enabled each output slot is
// padded to the common alignment before summing.
void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  std::vector<size_t> align_size_list;
  for (uint64_t mem_size : output_sizes) {
    if (context_ptr->enable_hccl()) {
      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    align_size_list.emplace_back(mem_size);
  }
  // Single allocation for the whole node; outputs are laid out back to back.
  uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    // Each address records the original (unaligned) size, while the pointer
    // advances by the aligned stride.
    auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type);
    AnfAlgo::SetOutputAddr(address, j, node.get());
    output_ptr += align_size_list[j];
  }
}
// Repack the inputs of a communication node into one contiguous block:
// allocate a single buffer sized for all producer outputs (aligned under
// HCCL) and re-point each producer's output address into that buffer.
void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
    auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i);
    MS_EXCEPTION_IF_NULL(address);
    auto mem_size = address->size();
    if (context_ptr->enable_hccl()) {
      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    addr_size.emplace_back(address.get(), mem_size);
  }
  // NOTE(review): the block is allocated with kDynamicMem regardless of the
  // caller's memory flag — presumably intentional; confirm against the memory
  // manager's reuse semantics.
  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size);
  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}
// Assign output memory for a single node. `index` selects one output, or
// kGetAllOuts for all of them. Communication nodes take the contiguous path;
// GetNext never draws from the reuse allocator.
void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (AnfAlgo::IsCommunicationOp(node)) {
    UpdateCommunicationOpInputMem(node);
    AssignCommunicationNodeOutputMem(flag, node);
    return;
  }
  // GetNext is excluded from memory reuse; downgrade to plain dynamic memory.
  if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
    MS_LOG(INFO) << "GetNext disable mem_reuse";
    flag = kDynamicMem;
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    // Honor the single-output selection unless all outputs were requested.
    if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
      continue;
    }
    if (AnfAlgo::OutputAddrExist(node, i)) {
      MS_LOG(INFO) << "Already malloc index:" << i;
      continue;
    }
    auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]);
    if (ptr == nullptr) {
      // reused ptr, no need alloc, continue;
      continue;
    }
    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
    AnfAlgo::SetOutputAddr(CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type), i, node.get());
  }
}
// Assign static device memory for a value node holding a Tensor and copy the
// host tensor data to the device. Non-tensor values are ignored with a warning.
void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
                                          size_t output_idx) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto tensor = node_value->cast<TensorPtr>();
  if (tensor == nullptr) {
    MS_LOG(WARNING) << "Tensor is null";
    return;
  }
  // Host payload size; node_size is the (possibly padded) device footprint.
  // The copy below transfers tensor_size bytes into the node_size allocation.
  size_t tensor_size = tensor->data().nbytes();
  auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
  auto ptr = mem_manager_->MallocMem(kStaticMem, node_size);
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  if (output_type_id == kTypeUnknown) {
    output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  }
  auto address = CreateDeviceAddress(ptr, node_size, AnfAlgo::GetOutputFormat(value_node, output_idx), output_type_id);
  MS_EXCEPTION_IF_NULL(address);
  AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
  // NOTE(review): the padding shape is taken from output 0 even when
  // output_idx differs — confirm this is intended for multi-output value nodes.
  if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
                                 tensor->data_c(false))) {
    MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
                                 << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
                                 << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  }
}
  417. void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
  418. MS_EXCEPTION_IF_NULL(graph);
  419. MS_EXCEPTION_IF_NULL(mem_manager_);
  420. for (auto &value_node : graph->graph_value_nodes()) {
  421. MS_EXCEPTION_IF_NULL(value_node);
  422. if (AnfAlgo::OutputAddrExist(value_node, 0)) {
  423. MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist";
  424. continue;
  425. }
  426. auto &node_value = value_node->value();
  427. MS_EXCEPTION_IF_NULL(node_value);
  428. if (node_value->isa<Tensor>()) {
  429. AssignValueNodeTensor(value_node, node_value, 0);
  430. } else if (node_value->isa<StringImm>()) {
  431. auto value = GetValue<std::string>(node_value);
  432. size_t tensor_size = value.size();
  433. auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
  434. auto address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  435. MS_EXCEPTION_IF_NULL(address);
  436. AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  437. std::vector<int> shape = {1, SizeToInt(tensor_size)};
  438. if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
  439. MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!";
  440. }
  441. }
  442. }
  443. }
  444. void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
  445. MS_EXCEPTION_IF_NULL(graph);
  446. MS_EXCEPTION_IF_NULL(mem_manager_);
  447. auto context_ptr = MsContext::GetInstance();
  448. MS_EXCEPTION_IF_NULL(context_ptr);
  449. bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
  450. auto mem_flag = kDynamicMem;
  451. if (is_enable_mem_reuse) {
  452. mem_manager_->MallocReusedDynamicMem(graph);
  453. mem_flag = kReuseDynamicMem;
  454. }
  455. auto &kernels = graph->execution_order();
  456. for (auto &kernel : kernels) {
  457. AssignNodeOutputMem(mem_flag, kernel, kGetAllOuts);
  458. AssignWorkSpaceMem(mem_flag, kernel);
  459. }
  460. }
  461. void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
  462. MS_EXCEPTION_IF_NULL(node);
  463. MS_EXCEPTION_IF_NULL(mem_manager_);
  464. auto kernel_mod = AnfAlgo::GetKernelMod(node);
  465. MS_EXCEPTION_IF_NULL(kernel_mod);
  466. size_t index = 0;
  467. for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
  468. auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size);
  469. AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
  470. index++;
  471. }
  472. }
// Collect the input/workspace/output device addresses of `kernel` into the
// Address lists consumed by KernelMod::Launch. AtomicAddrClean nodes are
// special-cased: their "inputs" are the addresses they must zero.
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
                                  AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
                                  AddressPtrList *kernel_outputs) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  MS_EXCEPTION_IF_NULL(kernel_workspaces);
  MS_EXCEPTION_IF_NULL(kernel_outputs);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
    return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
  }
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
    // Map the i-th logical input to its real producer index before resolving
    // the producer's output address.
    auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
    auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
    kernel::AddressPtr input = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(input);
    input->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(input->addr);
    input->size = device_address->size_;
    kernel_inputs->emplace_back(input);
  }
  for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
    kernel::AddressPtr output = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(output);
    output->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(output->addr);
    output->size = device_address->size_;
    kernel_outputs->emplace_back(output);
  }
  for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(workspace);
    workspace->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(workspace->addr);
    workspace->size = device_address->size_;
    kernel_workspaces->emplace_back(workspace);
  }
}
  514. void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) {
  515. if (cnode->inputs().size() != 2) {
  516. MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
  517. }
  518. auto pre_node = cnode->inputs()[1];
  519. // set clean output address
  520. if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) {
  521. auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAutomicOutputIndexs);
  522. for (auto index : clean_output_indexs) {
  523. auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
  524. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  525. MS_EXCEPTION_IF_NULL(input);
  526. input->addr = device_address->ptr_;
  527. MS_EXCEPTION_IF_NULL(input->addr);
  528. input->size = device_address->size_;
  529. kernel_inputs->emplace_back(input);
  530. }
  531. MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
  532. }
  533. // set clean workspace address
  534. if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) {
  535. auto clean_workspaces = AnfAlgo::GetNodeAttr<int>(pre_node, kAttrAutomicWorkspaceSize);
  536. if (clean_workspaces != 0) {
  537. auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, 0);
  538. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  539. MS_EXCEPTION_IF_NULL(workspace);
  540. workspace->addr = device_address->ptr_;
  541. MS_EXCEPTION_IF_NULL(workspace->addr);
  542. workspace->size = device_address->size_;
  543. kernel_inputs->emplace_back(workspace);
  544. }
  545. MS_LOG(INFO) << "AtomicAddClean clean workspace size" << clean_workspaces;
  546. }
  547. }
// Launch every kernel of `graph` in execution order on stream_, logging the
// per-kernel wall time. TBE kernels are additionally synchronized right after
// launch. Returns false as soon as one launch fails.
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  for (const auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    AddressPtrList kernel_inputs;
    AddressPtrList kernel_workspaces;
    AddressPtrList kernel_outputs;
    GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
#if defined(_WIN32) || defined(_WIN64)
    auto start_time = std::chrono::steady_clock::now();
#else
    struct timeval start_time, end_time;
    (void)gettimeofday(&start_time, nullptr);
#endif
    auto ret =
      kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, reinterpret_cast<uintptr_t>(stream_));
    if (!ret) {
      MS_LOG(ERROR) << "Launch kernel failed.";
      return false;
    } else {
      // TBE kernels are synchronized individually; a sync failure is fatal.
      if (AnfAlgo::GetKernelType(kernel) == TBE_KERNEL && !SyncStream()) {
        MS_LOG(EXCEPTION) << "SyncStream failed.";
      }
#if defined(_WIN32) || defined(_WIN64)
      auto end_time = std::chrono::steady_clock::now();
      std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
      MS_LOG(DEBUG) << "d " << kernel->fullname_with_scope() << " in " << cost.count() << " us";
#else
      (void)gettimeofday(&end_time, nullptr);
      const uint64_t kUSecondInSecond = 1000000;
      uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
      cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
      MS_LOG(DEBUG) << "d " << kernel->fullname_with_scope() << " in " << cost << " us";
#endif
    }
  }
  return true;
}
// Launch all kernels of `graph` on the current stream and wait for the
// stream to finish. Returns false (with an error log) on either failure.
bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!LaunchKernelMod(*graph)) {
    MS_LOG(ERROR) << "LaunchKernelMod failed!";
    return false;
  }
  if (!SyncStream()) {
    MS_LOG(ERROR) << "SyncStream failed!";
    return false;
  }
  return true;
}
  599. #ifdef ENABLE_DUMP_E2E
  600. bool KernelRuntime::SetDumpConf() {
  601. dump_conf_ptr_ = std::make_shared<Dump>();
  602. MS_EXCEPTION_IF_NULL(dump_conf_ptr_);
  603. bool ret = dump_conf_ptr_->SetDumpConfFromJsonFile();
  604. return ret;
  605. }
  606. DumpConfPtr KernelRuntime::GetDumpConf() { return dump_conf_ptr_; }
  607. #endif
  608. } // namespace device
  609. } // namespace mindspore