You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

kernel_runtime.cc 35 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/device/kernel_runtime.h"
  17. #include <vector>
  18. #include <utility>
  19. #include <numeric>
  20. #include "utils/ms_utils.h"
  21. #include "common/trans.h"
  22. #include "utils/utils.h"
  23. #include "utils/ms_context.h"
  24. #include "frontend/operator/ops.h"
  25. #include "backend/session/kernel_graph.h"
  26. #include "backend/session/anf_runtime_algorithm.h"
  27. #include "backend/optimizer/common/helper.h"
  28. #include "ir/value.h"
  29. using mindspore::kernel::Address;
  30. using mindspore::kernel::AddressPtr;
  31. namespace mindspore {
  32. namespace device {
// Destructor: releases the E2E dump configuration when dump support is
// compiled in; all other members clean up via their own destructors.
KernelRuntime::~KernelRuntime() {
#ifdef ENABLE_DUMP_E2E
  // Drop our reference to the shared dump configuration.
  dump_conf_ptr_ = nullptr;
#endif
}
// Executes the graph either through the task-sink path (whole-graph offload)
// or by launching kernels one at a time, and logs the wall-clock cost in
// microseconds. Returns the result of the chosen execution path.
// NOTE(review): `debugger` is unused here — presumably consumed by derived
// runtimes overriding RunTask/LaunchKernel; confirm against subclasses.
bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
  bool ret = false;
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
#if defined(_WIN32) || defined(_WIN64)
  // Windows has no gettimeofday; use the monotonic steady_clock instead.
  auto start_time = std::chrono::steady_clock::now();
#else
  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);
#endif
  bool is_task_sink = context_ptr->enable_task_sink();
  if (is_task_sink) {
    ret = RunTask(graph);
  } else {
    ret = LaunchKernel(graph);
  }
#if defined(_WIN32) || defined(_WIN64)
  auto end_time = std::chrono::steady_clock::now();
  std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
  MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
#else
  (void)gettimeofday(&end_time, nullptr);
  const uint64_t kUSecondInSecond = 1000000;
  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
#endif
  return ret;
}
  67. // for D to impl
  68. bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
  69. if (graph != nullptr) {
  70. return true;
  71. }
  72. return false;
  73. }
  74. // for D to impl
  75. bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
  76. if (graph != nullptr) {
  77. return true;
  78. }
  79. return false;
  80. }
  81. // for D to impl
  82. bool KernelRuntime::GenTask(const session::KernelGraph *graph) {
  83. if (graph != nullptr) {
  84. return true;
  85. }
  86. return false;
  87. }
  88. bool KernelRuntime::LoadTask(const session::KernelGraph *graph) {
  89. if (graph != nullptr) {
  90. return true;
  91. }
  92. return false;
  93. }
  94. // for D to impl
  95. bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
  96. if (graph != nullptr) {
  97. return true;
  98. }
  99. return false;
  100. }
  101. bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
  102. MS_EXCEPTION_IF_NULL(kernel);
  103. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  104. return true;
  105. }
  106. return false;
  107. }
// Computes the device memory size in bytes required by one output of `node`:
// the element-type size multiplied across every dimension of the device-format
// shape. Throws ArgumentError when `output_index` is out of range.
size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
    MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
                                << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
  }
  // Prefer the selected device data type; fall back to the inferred type when
  // the kernel has not been assigned one yet.
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (output_type_id == kTypeUnknown) {
    output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  }
  size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
  std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
  auto format = AnfAlgo::GetOutputFormat(node, output_index);
  // NOTE(review): padding/device-shape translation only happens when the shape
  // is EMPTY and the format is non-default — confirm that a non-empty shape in
  // a special format is intentionally used as-is.
  if (shape.empty() && format != kOpFormat_DEFAULT) {
    shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
    shape = trans::TransShapeToDevice(shape, format);
  }
  // scalar's output shape is a empty vector; accumulate then yields type_size.
  size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
  return tensor_size;
}
// Plans all memory for a whole-graph run: resets the dynamic region, assigns
// static memory (inputs, value nodes, outputs), then dynamic memory for
// intermediates, and finally re-points ref-node outputs at their origins.
void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->ResetDynamicMemory();
  AssignStaticMemory(graph);
  AssignDynamicMemory(graph);
  UpdateRefNodeOutputMem(graph);
}
// Plans memory for a single-op (pynative) run: binds input tensors, value
// nodes, and any addresses carried over from the previous op's outputs, then
// allocates each kernel's outputs and workspaces and fixes up ref outputs.
void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value,
                                      const std::vector<tensor::TensorPtr> &input_tensors,
                                      session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  RunOpAssignInputMemory(input_tensors, graph);
  AssignStaticMemoryValueNode(graph);
  // Reuse addresses of the previous op's outputs where the graph consumes them.
  RunOpAssignOutputNodeMemory(pre_output_value, graph);
  for (const auto &cnode : graph->execution_order()) {
    RunOpAssignOutputMemory(cnode);
    RunOpAssignWorkSpaceMemory(cnode);
  }
  UpdateRefNodeOutputMem(graph);
}
  151. void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) {
  152. MS_EXCEPTION_IF_NULL(graph);
  153. // clear input parameter memory resource
  154. for (const auto &input_node : graph->inputs()) {
  155. MS_EXCEPTION_IF_NULL(input_node);
  156. AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
  157. }
  158. // clear input value node memory resource
  159. for (const auto &value_node : graph->graph_value_nodes()) {
  160. MS_EXCEPTION_IF_NULL(value_node);
  161. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  162. }
  163. for (const auto &cnode : graph->execution_order()) {
  164. MS_EXCEPTION_IF_NULL(cnode);
  165. // clear output memory resource
  166. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
  167. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  168. }
  169. // clear workspace memory resource
  170. auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  171. MS_EXCEPTION_IF_NULL(kernel_mod);
  172. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  173. for (size_t index = 0; index < workspace_lists.size(); ++index) {
  174. AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
  175. }
  176. }
  177. }
  178. bool KernelRuntime::DumpDataEnabled() {
  179. bool ret = false;
  180. #ifdef ENABLE_DUMP_E2E
  181. DumpConfPtr dump_conf = GetDumpConf();
  182. MS_EXCEPTION_IF_NULL(dump_conf);
  183. bool dump_flag = dump_conf->dump_enable();
  184. if (!dump_flag) {
  185. return ret;
  186. }
  187. ret = true;
  188. #endif
  189. return ret;
  190. }
  191. bool KernelRuntime::DumpDataEnabledIteration() {
  192. bool ret = false;
  193. #ifdef ENABLE_DUMP_E2E
  194. if (!DumpDataEnabled()) {
  195. return ret;
  196. }
  197. DumpConfPtr dump_conf = GetDumpConf();
  198. MS_EXCEPTION_IF_NULL(dump_conf);
  199. uint32_t cur_iter = dump_conf->cur_iter() + 1;
  200. if (dump_conf->dump_iter() != 0) {
  201. if (cur_iter != dump_conf->dump_iter()) {
  202. return ret;
  203. }
  204. }
  205. ret = true;
  206. #endif
  207. return ret;
  208. }
// Assigns static (whole-session lifetime) memory in three passes:
// graph inputs, value nodes (weights/constants), then graph outputs.
void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) {
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  AssignStaticMemoryOutput(graph);
}
// Binds device memory to every graph input parameter for a single-op run.
// Inputs whose tensors already carry a device address share it directly;
// otherwise fresh memory is drawn from the memory pool.
// Requires input_tensors to match graph->inputs() one-to-one (throws if not).
void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                           const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (input_tensors.size() != graph->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
                      << " should be equal to graph input parameter size " << graph->inputs().size();
  }
  for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) {
    auto item = graph->inputs()[input_index];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      MS_EXCEPTION_IF_NULL(input_tensors[input_index]);
      // Reuse the address already attached to the input tensor when present.
      auto output_address =
        std::dynamic_pointer_cast<device::DeviceAddress>(input_tensors[input_index]->device_address());
      if (output_address != nullptr) {
        AnfAlgo::SetOutputAddr(output_address, index, item.get());
        continue;
      }
      // Fall back to the inferred data type when no device type was selected.
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = CountNodeDeviceMemorySize(item, index);
      auto device_address =
        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
      MS_EXCEPTION_IF_NULL(device_address);
      MS_EXCEPTION_IF_NULL(mem_manager_);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Malloc device memory failed.";
      }
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}
  254. void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
  255. MS_EXCEPTION_IF_NULL(kernel);
  256. MS_EXCEPTION_IF_NULL(mem_manager_);
  257. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  258. MS_EXCEPTION_IF_NULL(kernel_mod);
  259. auto output_sizes = kernel_mod->GetOutputSizeList();
  260. if (output_sizes.empty()) {
  261. return;
  262. }
  263. for (size_t i = 0; i < output_sizes.size(); ++i) {
  264. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  265. continue;
  266. }
  267. if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) {
  268. auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
  269. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  270. continue;
  271. }
  272. std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
  273. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  274. auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  275. device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
  276. MS_EXCEPTION_IF_NULL(device_address);
  277. auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
  278. if (!ret) {
  279. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  280. }
  281. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  282. }
  283. }
  284. void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  285. MS_EXCEPTION_IF_NULL(kernel);
  286. MS_EXCEPTION_IF_NULL(mem_manager_);
  287. if (kernel->isa<CNode>()) {
  288. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  289. MS_EXCEPTION_IF_NULL(kernel_mod);
  290. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  291. for (size_t i = 0; i < workspace_lists.size(); ++i) {
  292. auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
  293. MS_EXCEPTION_IF_NULL(device_address);
  294. auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
  295. if (!ret) {
  296. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  297. }
  298. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  299. }
  300. }
  301. }
// Shares the device addresses of the previous op's output tensors with the
// graph's output nodes, so consecutive single-op runs avoid extra copies.
// No-op when there is no previous output. Throws when the tensor count does
// not match the graph's outputs or a previous tensor has no address.
void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph) {
  if (pre_output_value == nullptr) {
    return;
  }
  std::vector<tensor::TensorPtr> pre_output_tensors;
  TensorValueToTensor(pre_output_value, &pre_output_tensors);
  MS_EXCEPTION_IF_NULL(graph);
  auto output_nodes = graph->outputs();
  if (pre_output_tensors.size() != output_nodes.size()) {
    MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size()
                      << "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]";
  }
  // share output address with pre output tensors
  for (size_t i = 0; i < output_nodes.size(); ++i) {
    auto output_node_with_index = AnfAlgo::VisitKernel(output_nodes[i], 0);
    if (!output_node_with_index.first->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "The output node should be a cnode , but it is "
                        << output_node_with_index.first->DebugString();
    }
    auto real_output_cnode = output_node_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(real_output_cnode);
    MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
    if (pre_output_tensors[i]->device_address() == nullptr) {
      MS_LOG(EXCEPTION) << "The address of pre output tensor [" << i << "] is a nullptr!";
    }
    if (opt::IsNopNode(real_output_cnode)) {
      // A nop node forwards its first real input, so the shared address is
      // bound to that input rather than the nop node itself.
      if (real_output_cnode->inputs().size() < 2) {
        MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString()
                          << " should large than one!";
      }
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, real_output_cnode->input(1).get());
    } else {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, output_node_with_index.first.get());
    }
  }
}
// Allocates static device memory for every graph input parameter (including
// child-graph results) that does not yet hold a device address. Invalid
// inputs and non-parameters are skipped; MakeTuple inputs are expanded into
// their constituent parameter outputs first.
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto graph_inputs = graph->inputs();
  auto graph_valid_input = graph->valid_inputs();
  graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
  // First collect the nodes that actually need allocation.
  std::vector<AnfNodePtr> need_alloc_nodes;
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    auto item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
      // Expand tuple inputs and consider each element separately.
      auto outs = AnfAlgo::GetAllOutput(item);
      for (auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        if (!out->isa<Parameter>()) {
          continue;
        }
        if (NodeOutputDeviceAddressExist(out, 0)) {
          continue;
        }
        need_alloc_nodes.push_back(out);
      }
    }
    if (!item->isa<Parameter>()) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(item, 0)) {
      continue;
    }
    need_alloc_nodes.push_back(item);
  }
  // Then allocate static memory for every output of each collected node.
  for (auto &item : need_alloc_nodes) {
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = CountNodeDeviceMemorySize(item, index);
      auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
      if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
      }
      AnfAlgo::SetOutputAddr(address, index, item.get());
    }
  }
}
// Allocates static memory for the graph's final outputs. Communication ops
// are handled first because their inputs/outputs must be laid out together;
// all remaining real-kernel outputs are assigned afterwards.
void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  std::vector<session::KernelWithIndex> non_communication_op;
  // Assign Communicate Op Memory firstly.
  for (const auto &node : nodes) {
    auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    // Only real kernels (CNodes) receive static output memory here.
    if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
      continue;
    }
    if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
      AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
    } else {
      non_communication_op.emplace_back(item_with_index);
    }
  }
  for (const auto &item_with_index : non_communication_op) {
    AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  }
}
// For every kernel output registered in the graph's ref-output map, re-binds
// the output to the origin node's address so a ref pair (e.g. an in-place
// update and its source) shares one piece of device memory.
void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto &kernels = graph->execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto output_sizes = kernel_mod->GetOutputSizeList();
    if (output_sizes.empty()) {
      MS_LOG(INFO) << "This kernel has no output size.";
      continue;
    }
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      session::AnfWithOutIndex out_pair(kernel, i);
      if (graph->IsInRefOutputMap(out_pair)) {
        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
        MS_EXCEPTION_IF_NULL(origin_node_output_addr);
        auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
        // Only rebind when the two addresses actually differ.
        if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
          MS_LOG(INFO) << "REF address is not same, ref node output need address update";
          MS_LOG(INFO) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                       << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}
// Assigns memory for a communication node: contiguous input region first,
// then the contiguous output region.
void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
  AssignCommunicationNodeInputMem(type, node);
  AssignCommunicationNodeOutputMem(type, node);
}
// Lays out all outputs of a communication node in one contiguous device
// region: sizes are (optionally HCCL-)aligned, summed, allocated once, and
// then each output address is offset into the block.
// NOTE(review): outputs that already own an address are skipped when building
// align_size_list, but the second loop indexes output format/type/size by the
// align_size_list position j — if any output was skipped, the indices drift
// apart. Confirm whether partially-assigned outputs can occur here.
void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  size_t output_index = 0;
  std::vector<size_t> align_size_list;
  for (uint64_t mem_size : output_sizes) {
    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
      MS_LOG(INFO) << "communication op addr exist";
      continue;
    }
    // HCCL requires each chunk to be aligned to its common alignment.
    if (context_ptr->enable_hccl()) {
      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    align_size_list.emplace_back(mem_size);
  }
  if (type == kReuseDynamicMem) {
    // reuse communication op's all outputs' memory
    type = kReuseDynamicCommMem;
  }
  uint8_t *output_ptr = nullptr;
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
    MS_EXCEPTION_IF_NULL(address);
    if (output_ptr == nullptr) {
      // First chunk: allocate the whole contiguous block at once.
      output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address);
      MS_EXCEPTION_IF_NULL(output_ptr);
    } else {
      address->set_ptr(output_ptr);
    }
    AnfAlgo::SetOutputAddr(address, j, node.get());
    output_ptr += align_size_list[j];
  }
}
  493. DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
  494. MS_EXCEPTION_IF_NULL(anf_node);
  495. auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  496. auto output_sizes = kernel_mod->GetOutputSizeList();
  497. if (output_sizes.size() <= index) {
  498. MS_LOG(EXCEPTION) << "Previous node output size < node index";
  499. }
  500. std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  501. auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  502. auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
  503. AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  504. return address;
  505. }
// Lays out all inputs of a communication node contiguously: a device address
// is pre-assigned to each producing CNode output, sizes are aligned and
// summed, one block is allocated, and each input address is offset into it.
// Throws when any input is not produced by a CNode.
void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<std::pair<DeviceAddressPtr, size_t>> addr_size;
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
    auto input_node = input_node_with_index.first;
    DeviceAddressPtr address = nullptr;
    if (input_node->isa<CNode>()) {
      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
    } else {
      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
    }
    MS_EXCEPTION_IF_NULL(address);
    auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
    total_size += mem_size;
    addr_size.emplace_back(address, mem_size);
  }
  if (addr_size.empty()) {
    return;
  }
  // One allocation for the whole block; individual inputs point inside it.
  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first);
  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}
// Allocates device memory for the outputs of `node`. `index` selects a single
// output, or kGetAllOuts for all of them. GetNext and independent nodes opt
// out of memory reuse (they need stable, non-shared buffers).
void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (AnfAlgo::IsGetNext(NOT_NULL(node)) && type == kReuseDynamicMem) {
    MS_LOG(INFO) << "GetNext disable mem_reuse";
    type = kDynamicMem;
  }
  if (node->isa<CNode>()) {
    bool independent = AnfAlgo::IsIndependentNode(node->cast<CNodePtr>());
    if (independent && type == kReuseDynamicMem) {
      MS_LOG(INFO) << "Independent disable mem_reuse";
      type = kDynamicMem;
    }
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    // Skip outputs not selected by `index` and ones already holding memory.
    if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(node, i)) {
      MS_LOG(INFO) << "Already malloc index:" << i;
      continue;
    }
    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
    MS_EXCEPTION_IF_NULL(device_address);
    uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address);
    MS_EXCEPTION_IF_NULL(ptr);
    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
    AnfAlgo::SetOutputAddr(device_address, i, node.get());
  }
}
  576. void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
  577. size_t output_idx) {
  578. MS_EXCEPTION_IF_NULL(value_node);
  579. MS_EXCEPTION_IF_NULL(node_value);
  580. MS_EXCEPTION_IF_NULL(mem_manager_);
  581. auto ms_context = MsContext::GetInstance();
  582. MS_EXCEPTION_IF_NULL(ms_context);
  583. std::vector<tensor::TensorPtr> tensors;
  584. TensorValueToTensor(node_value, &tensors);
  585. for (const auto &tensor : tensors) {
  586. if (tensor == nullptr) {
  587. MS_LOG(WARNING) << "Tensor is null";
  588. return;
  589. }
  590. if (tensor->device_address() != nullptr) {
  591. AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
  592. value_node.get());
  593. continue;
  594. }
  595. size_t tensor_size = tensor->data().nbytes();
  596. auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
  597. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  598. if (output_type_id == kTypeUnknown) {
  599. output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  600. }
  601. auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
  602. DeviceAddressPtr address = nullptr;
  603. address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
  604. MS_EXCEPTION_IF_NULL(address);
  605. if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) {
  606. MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size;
  607. } else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) {
  608. MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
  609. }
  610. AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
  611. if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
  612. tensor->data_c())) {
  613. MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString()
  614. << "node format is" << AnfAlgo::GetOutputFormat(value_node, output_idx)
  615. << "node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  616. }
  617. }
  618. }
  619. void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
  620. MS_EXCEPTION_IF_NULL(graph);
  621. MS_EXCEPTION_IF_NULL(mem_manager_);
  622. auto ms_context = MsContext::GetInstance();
  623. MS_EXCEPTION_IF_NULL(ms_context);
  624. for (auto &value_node : graph->graph_value_nodes()) {
  625. MS_EXCEPTION_IF_NULL(value_node);
  626. if (NodeOutputDeviceAddressExist(value_node, 0)) {
  627. MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist";
  628. continue;
  629. }
  630. auto &node_value = value_node->value();
  631. MS_EXCEPTION_IF_NULL(node_value);
  632. if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>()) {
  633. AssignValueNodeTensor(value_node, node_value, 0);
  634. } else if (node_value->isa<StringImm>()) {
  635. auto value = GetValue<std::string>(node_value);
  636. size_t tensor_size = value.size();
  637. DeviceAddressPtr address = nullptr;
  638. address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  639. MS_EXCEPTION_IF_NULL(address);
  640. if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
  641. MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size;
  642. } else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
  643. MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
  644. }
  645. AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  646. std::vector<int> shape = {1, SizeToInt(tensor_size)};
  647. if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
  648. MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!";
  649. }
  650. }
  651. }
  652. }
  653. void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
  654. MS_EXCEPTION_IF_NULL(graph);
  655. MS_EXCEPTION_IF_NULL(mem_manager_);
  656. auto context_ptr = MsContext::GetInstance();
  657. MS_EXCEPTION_IF_NULL(context_ptr);
  658. bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
  659. auto mem_type = kDynamicMem;
  660. if (is_enable_mem_reuse) {
  661. mem_manager_->MallocReusedDynamicMem(graph);
  662. mem_type = kReuseDynamicMem;
  663. }
  664. auto &execution_nodes = graph->execution_order();
  665. std::vector<CNodePtr> compute_nodes;
  666. // communication nodes first
  667. for (auto &node : execution_nodes) {
  668. if (AnfAlgo::IsCommunicationOp(node)) {
  669. // skip if the memory is already alocated
  670. AssignCommunicationNodeMem(mem_type, node);
  671. } else {
  672. compute_nodes.emplace_back(node);
  673. }
  674. }
  675. // then compute nodes
  676. for (auto &node : compute_nodes) {
  677. AssignNodeOutputMem(mem_type, node, kGetAllOuts);
  678. AssignWorkSpaceMem(mem_type, node);
  679. }
  680. }
  681. void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
  682. MS_EXCEPTION_IF_NULL(node);
  683. MS_EXCEPTION_IF_NULL(mem_manager_);
  684. auto kernel_mod = AnfAlgo::GetKernelMod(node);
  685. MS_EXCEPTION_IF_NULL(kernel_mod);
  686. size_t index = 0;
  687. for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
  688. auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
  689. AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
  690. index++;
  691. }
  692. }
  693. void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
  694. AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
  695. AddressPtrList *kernel_outputs) {
  696. MS_EXCEPTION_IF_NULL(kernel);
  697. MS_EXCEPTION_IF_NULL(kernel_inputs);
  698. MS_EXCEPTION_IF_NULL(kernel_workspaces);
  699. MS_EXCEPTION_IF_NULL(kernel_outputs);
  700. auto cnode = kernel->cast<CNodePtr>();
  701. MS_EXCEPTION_IF_NULL(cnode);
  702. if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
  703. return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
  704. }
  705. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
  706. auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
  707. auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
  708. MS_EXCEPTION_IF_NULL(device_address);
  709. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  710. MS_EXCEPTION_IF_NULL(input);
  711. input->addr = device_address->ptr_;
  712. MS_EXCEPTION_IF_NULL(input->addr);
  713. input->size = device_address->size_;
  714. kernel_inputs->emplace_back(input);
  715. }
  716. for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
  717. auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
  718. kernel::AddressPtr output = std::make_shared<kernel::Address>();
  719. MS_EXCEPTION_IF_NULL(output);
  720. output->addr = device_address->ptr_;
  721. MS_EXCEPTION_IF_NULL(output->addr);
  722. output->size = device_address->size_;
  723. kernel_outputs->emplace_back(output);
  724. }
  725. for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
  726. auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
  727. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  728. MS_EXCEPTION_IF_NULL(workspace);
  729. workspace->addr = device_address->ptr_;
  730. MS_EXCEPTION_IF_NULL(workspace->addr);
  731. workspace->size = device_address->size_;
  732. kernel_workspaces->emplace_back(workspace);
  733. }
  734. }
  735. void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) {
  736. if (cnode->inputs().size() != 2) {
  737. MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
  738. }
  739. MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
  740. auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  741. // set clean output address
  742. if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
  743. auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
  744. for (auto index : clean_output_indexs) {
  745. auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
  746. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  747. MS_EXCEPTION_IF_NULL(input);
  748. input->addr = device_address->ptr_;
  749. MS_EXCEPTION_IF_NULL(input->addr);
  750. input->size = device_address->size_;
  751. kernel_inputs->emplace_back(input);
  752. }
  753. MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
  754. }
  755. // set clean workspace address
  756. if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
  757. auto clean_workspaces_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
  758. for (const auto &index : clean_workspaces_indexs) {
  759. auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
  760. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  761. MS_EXCEPTION_IF_NULL(workspace);
  762. workspace->addr = device_address->ptr_;
  763. MS_EXCEPTION_IF_NULL(workspace->addr);
  764. workspace->size = device_address->size_;
  765. kernel_inputs->emplace_back(workspace);
  766. }
  767. }
  768. }
  769. bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
  770. auto &kernels = graph.execution_order();
  771. for (const auto &kernel : kernels) {
  772. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  773. MS_EXCEPTION_IF_NULL(kernel_mod);
  774. AddressPtrList kernel_inputs;
  775. AddressPtrList kernel_workspaces;
  776. AddressPtrList kernel_outputs;
  777. GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  778. auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  779. if (!ret) {
  780. MS_LOG(ERROR) << "Launch kernel failed.";
  781. return false;
  782. }
  783. }
  784. return true;
  785. }
  786. bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
  787. MS_EXCEPTION_IF_NULL(graph);
  788. if (!LaunchKernelMod(*graph)) {
  789. MS_LOG(ERROR) << "LaunchKernelMod failed!";
  790. return false;
  791. }
  792. return true;
  793. }
// Release runtime resources owned by the given graph. The base implementation
// only logs; device-specific subclasses presumably override this to free real
// resources (NOTE(review): confirm against derived runtimes).
void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
  MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}
  797. bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr,
  798. const AddressPtrList &kernel_inputs,
  799. const AddressPtrList &kernel_outputs,
  800. const AddressPtrList &kernel_workspaces) const {
  801. MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
  802. auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  803. if (!ret) {
  804. MS_LOG(ERROR) << "Launch kernel failed.";
  805. return false;
  806. }
  807. return true;
  808. }
  809. DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const std::string &format, TypeId type) {
  810. auto device_address = CreateDeviceAddress(nullptr, size, format, type);
  811. MS_EXCEPTION_IF_NULL(device_address);
  812. MS_EXCEPTION_IF_NULL(mem_manager_);
  813. auto base_ptr = mem_manager_->MallocMem(kStaticMem, size, device_address);
  814. MS_EXCEPTION_IF_NULL(base_ptr);
  815. return device_address;
  816. }
  817. #ifdef ENABLE_DUMP_E2E
  818. bool KernelRuntime::SetDumpConf() {
  819. dump_conf_ptr_ = std::make_shared<Dump>();
  820. MS_EXCEPTION_IF_NULL(dump_conf_ptr_);
  821. bool ret = dump_conf_ptr_->SetDumpConfFromJsonFile();
  822. return ret;
  823. }
// Accessor for the dump configuration created by SetDumpConf().
DumpConfPtr KernelRuntime::GetDumpConf() { return dump_conf_ptr_; }
  825. #endif
  826. } // namespace device
  827. } // namespace mindspore