You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

kernel_runtime.cc 40 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/device/kernel_runtime.h"
  17. #include <functional>
  18. #include <numeric>
  19. #include <utility>
  20. #include <vector>
  21. #include "backend/optimizer/common/helper.h"
  22. #include "backend/session/anf_runtime_algorithm.h"
  23. #include "backend/session/kernel_graph.h"
  24. #include "common/trans.h"
  25. #include "debug/data_dump/dump_json_parser.h"
  26. #include "frontend/operator/ops.h"
  27. #include "ir/value.h"
  28. #include "utils/ms_context.h"
  29. #include "utils/ms_utils.h"
  30. #include "utils/shape_utils.h"
  31. #include "utils/utils.h"
  32. using mindspore::kernel::Address;
  33. using mindspore::kernel::AddressPtr;
  34. namespace mindspore {
  35. namespace device {
  36. KernelRuntime::~KernelRuntime() {}
  37. bool KernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { return true; }
  38. bool KernelRuntime::LoadData(session::KernelGraph *graph) { return false; }
  39. bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
  40. MS_EXCEPTION_IF_NULL(kernel);
  41. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  42. return true;
  43. }
  44. return false;
  45. }
  46. size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
  47. MS_EXCEPTION_IF_NULL(node);
  48. if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
  49. MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
  50. << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
  51. }
  52. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  53. if (output_type_id == kTypeUnknown) {
  54. output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  55. }
  56. size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
  57. std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
  58. auto format = AnfAlgo::GetOutputFormat(node, output_index);
  59. if (shape.empty() && format != kOpFormat_DEFAULT) {
  60. shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
  61. shape = trans::TransShapeToDevice(shape, format);
  62. }
  63. // scalar's output shape is a empty vector
  64. size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
  65. return tensor_size;
  66. }
  67. void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
  68. auto context_ptr = MsContext::GetInstance();
  69. MS_EXCEPTION_IF_NULL(context_ptr);
  70. MS_EXCEPTION_IF_NULL(mem_manager_);
  71. mem_manager_->ResetDynamicMemory();
  72. AssignStaticMemory(graph);
  73. AssignDynamicMemory(graph);
  74. UpdateRefNodeOutputMem(graph);
  75. }
  76. void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value,
  77. const std::vector<tensor::TensorPtr> &input_tensors,
  78. session::KernelGraph *graph) {
  79. MS_EXCEPTION_IF_NULL(graph);
  80. MS_EXCEPTION_IF_NULL(mem_manager_);
  81. mem_manager_->ResetDynamicMemory();
  82. RunOpAssignInputMemory(input_tensors, graph);
  83. AssignStaticMemoryValueNode(graph);
  84. RunOpAssignOutputNodeMemory(pre_output_value, graph);
  85. for (const auto &cnode : graph->execution_order()) {
  86. RunOpAssignOutputMemory(cnode);
  87. RunOpAssignWorkSpaceMemory(cnode);
  88. }
  89. UpdateRefNodeOutputMem(graph);
  90. }
  91. void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) {
  92. MS_EXCEPTION_IF_NULL(graph);
  93. // clear input parameter memory resource
  94. for (const auto &input_node : graph->inputs()) {
  95. MS_EXCEPTION_IF_NULL(input_node);
  96. AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
  97. }
  98. // clear input value node memory resource
  99. for (const auto &value_node : graph->graph_value_nodes()) {
  100. MS_EXCEPTION_IF_NULL(value_node);
  101. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  102. }
  103. for (const auto &cnode : graph->execution_order()) {
  104. MS_EXCEPTION_IF_NULL(cnode);
  105. // clear output memory resource
  106. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
  107. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  108. }
  109. // clear workspace memory resource
  110. auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  111. MS_EXCEPTION_IF_NULL(kernel_mod);
  112. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  113. for (size_t index = 0; index < workspace_lists.size(); ++index) {
  114. AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
  115. }
  116. }
  117. }
  118. bool KernelRuntime::DumpDataEnabled() {
  119. auto &dump_json_parser = DumpJsonParser::GetInstance();
  120. return dump_json_parser.e2e_dump_enabled();
  121. }
  122. bool KernelRuntime::DumpDataEnabledIteration() {
  123. auto &dump_json_parser = DumpJsonParser::GetInstance();
  124. if (!dump_json_parser.e2e_dump_enabled()) {
  125. return false;
  126. }
  127. auto cur_iter = dump_json_parser.cur_dump_iter() + 1;
  128. if (dump_json_parser.iteration() != 0) {
  129. return cur_iter == dump_json_parser.iteration();
  130. }
  131. return true;
  132. }
  133. void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) {
  134. AssignStaticMemoryInput(graph);
  135. AssignStaticMemoryValueNode(graph);
  136. AssignStaticMemoryOutput(graph);
  137. }
  138. void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
  139. const session::KernelGraph *graph) {
  140. MS_EXCEPTION_IF_NULL(graph);
  141. MS_EXCEPTION_IF_NULL(mem_manager_);
  142. if (input_tensors.size() != graph->inputs().size()) {
  143. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
  144. << " should be equal to graph input parameter size " << graph->inputs().size();
  145. }
  146. for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) {
  147. auto item = graph->inputs()[input_index];
  148. MS_EXCEPTION_IF_NULL(item);
  149. if (!item->isa<Parameter>()) {
  150. continue;
  151. }
  152. auto output_size = AnfAlgo::GetOutputTensorNum(item);
  153. for (size_t index = 0; index < output_size; index++) {
  154. MS_EXCEPTION_IF_NULL(input_tensors[input_index]);
  155. auto output_address =
  156. std::dynamic_pointer_cast<device::DeviceAddress>(input_tensors[input_index]->device_address());
  157. if (output_address != nullptr) {
  158. AnfAlgo::SetOutputAddr(output_address, index, item.get());
  159. continue;
  160. }
  161. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  162. if (output_type_id == kTypeUnknown) {
  163. output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
  164. }
  165. auto tensor_size = CountNodeDeviceMemorySize(item, index);
  166. auto device_address =
  167. CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
  168. MS_EXCEPTION_IF_NULL(device_address);
  169. MS_EXCEPTION_IF_NULL(mem_manager_);
  170. auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
  171. if (!ret) {
  172. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  173. }
  174. AnfAlgo::SetOutputAddr(device_address, index, item.get());
  175. }
  176. }
  177. }
  178. void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
  179. MS_EXCEPTION_IF_NULL(kernel);
  180. MS_EXCEPTION_IF_NULL(mem_manager_);
  181. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  182. MS_EXCEPTION_IF_NULL(kernel_mod);
  183. auto output_sizes = kernel_mod->GetOutputSizeList();
  184. if (output_sizes.empty()) {
  185. return;
  186. }
  187. for (size_t i = 0; i < output_sizes.size(); ++i) {
  188. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  189. continue;
  190. }
  191. if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) {
  192. auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
  193. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  194. continue;
  195. }
  196. std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
  197. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  198. auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  199. device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
  200. MS_EXCEPTION_IF_NULL(device_address);
  201. auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
  202. if (!ret) {
  203. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  204. }
  205. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  206. }
  207. }
  208. void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  209. MS_EXCEPTION_IF_NULL(kernel);
  210. MS_EXCEPTION_IF_NULL(mem_manager_);
  211. if (kernel->isa<CNode>()) {
  212. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  213. MS_EXCEPTION_IF_NULL(kernel_mod);
  214. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  215. for (size_t i = 0; i < workspace_lists.size(); ++i) {
  216. auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
  217. MS_EXCEPTION_IF_NULL(device_address);
  218. auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
  219. if (!ret) {
  220. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  221. }
  222. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  223. }
  224. }
  225. }
// Reuses the device memory of a previous launch's outputs for this graph's
// outputs: pre_output_value is unpacked into tensors, and each tensor's device
// address (if any) is attached to the matching output node instead of
// allocating fresh memory.
void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph) {
  if (pre_output_value == nullptr) {
    return;
  }
  std::vector<tensor::TensorPtr> pre_output_tensors;
  TensorValueToTensor(pre_output_value, &pre_output_tensors);
  MS_EXCEPTION_IF_NULL(graph);
  auto output_nodes = graph->outputs();
  // Previous-output tensors and graph output nodes must correspond one-to-one.
  if (pre_output_tensors.size() != output_nodes.size()) {
    MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size()
                      << "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]";
  }
  // share output address with pre output tensors
  for (size_t i = 0; i < output_nodes.size(); ++i) {
    auto output_node_with_index = AnfAlgo::VisitKernel(output_nodes[i], 0);
    if (!output_node_with_index.first->isa<CNode>()) {
      // Non-CNode outputs: only parameters with a default (real weights) are legal here.
      if (output_node_with_index.first->isa<Parameter>()) {
        auto param = output_node_with_index.first->cast<ParameterPtr>();
        if (!param->has_default()) {
          MS_LOG(EXCEPTION) << "The output parameter should be real parameter!";
        }
      }
      continue;
    }
    auto real_output_cnode = output_node_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(real_output_cnode);
    MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
    if (pre_output_tensors[i]->device_address() == nullptr) {
      MS_LOG(INFO) << "The address of pre output tensor [" << i << "] is a nullptr!";
      continue;
    }
    if (opt::IsNopNode(real_output_cnode)) {
      // A nop node forwards its first real input, so the shared address is set
      // on that input rather than on the nop node itself.
      if (real_output_cnode->inputs().size() < 2) {
        MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString()
                          << " should large than one!";
      }
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, real_output_cnode->input(1).get());
    } else {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, output_node_with_index.first.get());
    }
  }
}
// Allocates static device memory for graph inputs (parameters, plus child-graph
// result nodes) that do not yet own a device address.
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  MS_LOG(INFO) << "AssignStaticMemoryInput start";
  auto graph_inputs = graph->inputs();
  auto graph_valid_input = graph->valid_inputs();
  graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
  std::vector<AnfNodePtr> need_alloc_nodes;
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    auto item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    // Skip inputs flagged invalid; appended child_graph_result entries lie beyond
    // valid_inputs' range, so the bounds check lets them through unconditionally.
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    // A MakeTuple input is expanded: each parameter inside it is collected separately.
    if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
      auto outs = AnfAlgo::GetAllOutput(item);
      for (auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        if (!out->isa<Parameter>()) {
          continue;
        }
        if (NodeOutputDeviceAddressExist(out, 0)) {
          continue;
        }
        need_alloc_nodes.push_back(out);
      }
    }
    if (!item->isa<Parameter>()) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(item, 0)) {
      continue;
    }
    need_alloc_nodes.push_back(item);
  }
  // Allocate static memory for every collected node/output pair.
  for (auto &item : need_alloc_nodes) {
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        continue;
      }
      auto tensor_size = CountNodeDeviceMemorySize(item, index);
      auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
      MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope();
      if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
      }
      MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
                   << " index: " << index << " size: " << tensor_size;
      AnfAlgo::SetOutputAddr(address, index, item.get());
    }
  }
  MS_LOG(INFO) << "AssignStaticMemoryInput end";
}
  327. void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
  328. MS_EXCEPTION_IF_NULL(graph);
  329. MS_LOG(INFO) << "AssignStaticMemoryOutput start";
  330. auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  331. std::vector<session::KernelWithIndex> non_communication_op;
  332. // Assign Communicate Op Memory firstly.
  333. for (const auto &node : nodes) {
  334. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
  335. MS_EXCEPTION_IF_NULL(item_with_index.first);
  336. if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
  337. continue;
  338. }
  339. if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
  340. AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
  341. } else {
  342. non_communication_op.emplace_back(item_with_index);
  343. }
  344. }
  345. for (const auto &item_with_index : non_communication_op) {
  346. MS_LOG(DEBUG) << "AssignNodeOutputMem for " << item_with_index.first->fullname_with_scope();
  347. AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  348. }
  349. MS_LOG(INFO) << "AssignStaticMemoryOutput end";
  350. }
// Propagates device addresses through ref-node pairs: when a kernel output is
// registered in the graph's ref-output map, it must share the exact device
// address object of its origin node's output, so the update rebinds the
// current node's output address whenever the two differ.
void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto &kernels = graph->execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto output_sizes = kernel_mod->GetOutputSizeList();
    if (output_sizes.empty()) {
      MS_LOG(INFO) << "This kernel has no output size.";
      continue;
    }
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      session::AnfWithOutIndex out_pair(kernel, i);
      if (graph->IsInRefOutputMap(out_pair)) {
        // This output is a ref to another node's output; fetch the origin.
        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
        MS_EXCEPTION_IF_NULL(origin_node_output_addr);
        auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
        // Compare by pointer identity: a ref pair must share the same address object.
        if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
          MS_LOG(INFO) << "REF address is not same, ref node output need address update";
          MS_LOG(INFO) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                       << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}
  381. void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
  382. AssignCommunicationNodeInputMem(type, node);
  383. AssignCommunicationNodeOutputMem(type, node);
  384. }
// Allocates one contiguous device block covering all outputs of a communication
// op, then carves per-output addresses out of it by aligned offsets.
void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  size_t output_index = 0;
  std::vector<size_t> align_size_list;
  for (uint64_t mem_size : output_sizes) {
    // Outputs that already own an address are excluded from the fused block.
    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
      MS_LOG(INFO) << "communication op addr exist";
      continue;
    }
    // Under HCCL, each segment is padded to the common alignment.
    if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    align_size_list.emplace_back(mem_size);
  }
  // NOTE(review): if only SOME outputs already had addresses above,
  // align_size_list is shorter than output_sizes, yet the loop below assigns
  // addresses starting from output 0 — the j indices no longer line up with the
  // skipped outputs. Confirm partially-assigned communication outputs cannot
  // occur here.
  if (type == kReuseDynamicMem) {
    // reuse communication op's all outputs' memory
    type = kReuseDynamicCommMem;
  }
  if (type == kReuseDynamicCommMem || type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }
  uint8_t *output_ptr = nullptr;
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
    MS_EXCEPTION_IF_NULL(address);
    if (output_ptr == nullptr) {
      // First segment: allocate the whole fused block in a single call.
      output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
      MS_EXCEPTION_IF_NULL(output_ptr);
    } else {
      // Subsequent segments just point into the already-allocated block.
      address->set_ptr(output_ptr);
    }
    AnfAlgo::SetOutputAddr(address, j, node.get());
    output_ptr += align_size_list[j];
  }
}
  438. bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) { return false; }
// Creates a device address (size/format/type only — no memory backing yet) for
// output `index` of `anf_node` and registers it on the node. Nop nodes are
// transparent: the request recurses to the real producer behind them. The
// caller fills in the actual pointer later (see AssignCommunicationNodeInputMem).
DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "anf_node should be a cnode";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  if (opt::IsNopNode(cnode)) {
    // A nop node holds exactly two inputs: the primitive and one real input.
    const size_t kNopNodeInputSize = 2;
    if (cnode->size() != kNopNodeInputSize) {
      MS_LOG(EXCEPTION) << cnode->fullname_with_scope() << " has invalid input size: " << cnode->size();
    }
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
    return PreAssignCNodeMemory(input_node_with_index.first, input_node_with_index.second);
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.size() <= index) {
    MS_LOG(EXCEPTION) << "Previous node output size < node index";
  }
  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  return address;
}
// Allocates one contiguous device block for all inputs of a communication op.
// Each input's producer gets a pre-created device address, and every address is
// pointed into consecutive aligned segments of the single block.
void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<std::pair<DeviceAddressPtr, size_t>> addr_size;
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
    auto input_node = input_node_with_index.first;
    DeviceAddressPtr address = nullptr;
    if (input_node->isa<CNode>()) {
      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
    } else {
      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
    }
    MS_EXCEPTION_IF_NULL(address);
    // Each segment is padded to the common alignment before fusing.
    auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
    total_size += mem_size;
    addr_size.emplace_back(address, mem_size);
  }
  if (addr_size.empty()) {
    return;
  }
  if (type == kReuseDynamicMem || type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input.";
    }
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().size() < 2) {
    // a communication node's inputs must contain the primitive itself and at least one real input
    MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
    return;
  }
  // The fused block is attributed to the first real input's producer for bookkeeping.
  auto first_input_node = cnode->input(1);
  auto prenode_index = AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true);
  uint8_t *input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size,
                                                     addr_size[0].first, true);
  // NOTE(review): input_ptr is used without a null check — confirm MallocOutputMem
  // cannot return nullptr on this path.
  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}
// Allocates device memory for the outputs of `node`. `index` selects one
// output, or kGetAllOuts for every output. `type` may be downgraded to
// kDynamicMem when memory reuse is disabled for this node.
void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  // GetNext outputs opt out of dynamic memory reuse — presumably their
  // addresses must stay stable across steps; TODO confirm.
  if (AnfAlgo::IsGetNext(NOT_NULL(node)) && type == kReuseDynamicMem) {
    MS_LOG(INFO) << "GetNext disable mem_reuse";
    type = kDynamicMem;
  }
  if (node->isa<CNode>()) {
    // Independent nodes likewise never share reused dynamic memory.
    bool independent = AnfAlgo::IsIndependentNode(node->cast<CNodePtr>());
    if (independent && (type == kReuseDynamicMem)) {
      MS_LOG(INFO) << "Independent node " << node->fullname_with_scope() << " disable memory reuse";
      type = kDynamicMem;
    }
  }
  // Per-node reuse policy hook (overridable by subclasses).
  if (type == kReuseDynamicMem || type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    // Honor the single-output selection unless all outputs were requested.
    if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(node, i)) {
      MS_LOG(INFO) << "Already malloc index:" << i;
      continue;
    }
    MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope() << " output memory size:" << output_sizes[i];
    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
    MS_EXCEPTION_IF_NULL(device_address);
    uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
    MS_EXCEPTION_IF_NULL(ptr);
    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
    AnfAlgo::SetOutputAddr(device_address, i, node.get());
  }
}
  560. void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
  561. size_t output_idx) {
  562. MS_EXCEPTION_IF_NULL(value_node);
  563. MS_EXCEPTION_IF_NULL(node_value);
  564. MS_EXCEPTION_IF_NULL(mem_manager_);
  565. auto ms_context = MsContext::GetInstance();
  566. MS_EXCEPTION_IF_NULL(ms_context);
  567. std::vector<tensor::TensorPtr> tensors;
  568. TensorValueToTensor(node_value, &tensors);
  569. for (const auto &tensor : tensors) {
  570. if (tensor == nullptr) {
  571. MS_LOG(WARNING) << "Tensor is null";
  572. return;
  573. }
  574. if (tensor->device_address() != nullptr) {
  575. AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
  576. value_node.get());
  577. continue;
  578. }
  579. size_t tensor_size = tensor->data().nbytes();
  580. auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
  581. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  582. if (output_type_id == kTypeUnknown) {
  583. output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  584. }
  585. auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
  586. DeviceAddressPtr address = nullptr;
  587. address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
  588. MS_EXCEPTION_IF_NULL(address);
  589. if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
  590. !mem_manager_->MallocMemFromMemPool(address, node_size)) {
  591. MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size;
  592. } else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) {
  593. MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
  594. }
  595. AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
  596. if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
  597. tensor->data_c())) {
  598. MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString()
  599. << "node format is" << AnfAlgo::GetOutputFormat(value_node, output_idx)
  600. << "node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  601. }
  602. }
  603. }
  604. void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
  605. MS_EXCEPTION_IF_NULL(graph);
  606. MS_EXCEPTION_IF_NULL(mem_manager_);
  607. MS_LOG(INFO) << "AssignStaticMemoryValueNode start";
  608. auto ms_context = MsContext::GetInstance();
  609. MS_EXCEPTION_IF_NULL(ms_context);
  610. for (auto &value_node : graph->graph_value_nodes()) {
  611. MS_EXCEPTION_IF_NULL(value_node);
  612. if (NodeOutputDeviceAddressExist(value_node, 0)) {
  613. MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist";
  614. continue;
  615. }
  616. auto &node_value = value_node->value();
  617. MS_EXCEPTION_IF_NULL(node_value);
  618. MS_LOG(DEBUG) << "Malloc memory for " << value_node->fullname_with_scope();
  619. if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>()) {
  620. AssignValueNodeTensor(value_node, node_value, 0);
  621. } else if (node_value->isa<StringImm>()) {
  622. auto value = GetValue<std::string>(node_value);
  623. size_t tensor_size = value.size();
  624. DeviceAddressPtr address = nullptr;
  625. address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  626. MS_EXCEPTION_IF_NULL(address);
  627. if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
  628. !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
  629. MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size;
  630. } else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
  631. MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
  632. }
  633. AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  634. ShapeVector shape = {1, SizeToLong(tensor_size)};
  635. if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
  636. MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!";
  637. }
  638. }
  639. }
  640. MS_LOG(INFO) << "AssignStaticMemoryValueNode end";
  641. }
  642. void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
  643. MS_EXCEPTION_IF_NULL(graph);
  644. MS_EXCEPTION_IF_NULL(mem_manager_);
  645. auto context_ptr = MsContext::GetInstance();
  646. MS_EXCEPTION_IF_NULL(context_ptr);
  647. bool is_enable_mem_reuse = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE);
  648. auto mem_type = kDynamicMem;
  649. auto &dump_json_parser = DumpJsonParser::GetInstance();
  650. if (dump_json_parser.e2e_dump_enabled() && dump_json_parser.dump_mode() == 0) {
  651. context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, false);
  652. is_enable_mem_reuse = false;
  653. MS_LOG(INFO) << "Disable Memory Reuse when e2e dump is enable and dump mode is set to dump all kernels";
  654. }
  655. if (is_enable_mem_reuse) {
  656. MS_LOG(INFO) << "Memory Reuse is enable...";
  657. #ifdef MEM_REUSE_DEBUG
  658. mem_manager_->MallocReusedDynamicMem(graph);
  659. mem_type = kReuseDynamicMem;
  660. #else
  661. mem_manager_->MallocSomasDynamicMem(graph);
  662. mem_type = kSomasReuseDynamicMem;
  663. #endif
  664. } else {
  665. MS_LOG(INFO) << "Memory Reuse is disable...";
  666. }
  667. auto &execution_nodes = graph->execution_order();
  668. std::vector<CNodePtr> compute_nodes;
  669. // communication nodes first
  670. for (auto &node : execution_nodes) {
  671. if (AnfAlgo::IsCommunicationOp(node)) {
  672. // skip if the memory is already allocated
  673. AssignCommunicationNodeMem(mem_type, node);
  674. } else {
  675. compute_nodes.emplace_back(node);
  676. }
  677. }
  678. // then compute nodes
  679. for (auto &node : compute_nodes) {
  680. AssignNodeOutputMem(mem_type, node, kGetAllOuts);
  681. AssignWorkSpaceMem(mem_type, node);
  682. }
  683. }
  684. void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
  685. MS_EXCEPTION_IF_NULL(node);
  686. MS_EXCEPTION_IF_NULL(mem_manager_);
  687. auto kernel_mod = AnfAlgo::GetKernelMod(node);
  688. MS_EXCEPTION_IF_NULL(kernel_mod);
  689. size_t index = 0;
  690. for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
  691. auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
  692. AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
  693. index++;
  694. }
  695. }
  696. void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
  697. AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
  698. AddressPtrList *kernel_outputs) {
  699. MS_EXCEPTION_IF_NULL(kernel);
  700. MS_EXCEPTION_IF_NULL(kernel_inputs);
  701. MS_EXCEPTION_IF_NULL(kernel_workspaces);
  702. MS_EXCEPTION_IF_NULL(kernel_outputs);
  703. auto cnode = kernel->cast<CNodePtr>();
  704. MS_EXCEPTION_IF_NULL(cnode);
  705. if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
  706. return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
  707. }
  708. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
  709. auto op_name = AnfAlgo::GetCNodeName(cnode);
  710. constexpr auto none_placeholder_index = 3;
  711. if (op_name == kDynamicRNNOpName && i == none_placeholder_index) {
  712. continue;
  713. }
  714. if (op_name == kDynamicGRUV2OpName) {
  715. auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
  716. auto item = std::find(none_index.begin(), none_index.end(), i);
  717. if (item != none_index.end()) {
  718. continue;
  719. }
  720. }
  721. auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
  722. auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
  723. MS_EXCEPTION_IF_NULL(device_address);
  724. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  725. MS_EXCEPTION_IF_NULL(input);
  726. input->addr = device_address->ptr_;
  727. MS_EXCEPTION_IF_NULL(input->addr);
  728. input->size = device_address->size_;
  729. kernel_inputs->emplace_back(input);
  730. }
  731. for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
  732. auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
  733. kernel::AddressPtr output = std::make_shared<kernel::Address>();
  734. MS_EXCEPTION_IF_NULL(output);
  735. output->addr = device_address->ptr_;
  736. MS_EXCEPTION_IF_NULL(output->addr);
  737. output->size = device_address->size_;
  738. kernel_outputs->emplace_back(output);
  739. }
  740. for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
  741. auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
  742. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  743. MS_EXCEPTION_IF_NULL(workspace);
  744. workspace->addr = device_address->ptr_;
  745. MS_EXCEPTION_IF_NULL(workspace->addr);
  746. workspace->size = device_address->size_;
  747. kernel_workspaces->emplace_back(workspace);
  748. }
  749. }
  750. void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) {
  751. if (cnode->inputs().size() != 2) {
  752. MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
  753. }
  754. MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
  755. auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  756. // set clean output address
  757. if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
  758. auto clean_output_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
  759. for (auto index : clean_output_indexes) {
  760. auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
  761. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  762. MS_EXCEPTION_IF_NULL(input);
  763. input->addr = device_address->ptr_;
  764. MS_EXCEPTION_IF_NULL(input->addr);
  765. input->size = device_address->size_;
  766. kernel_inputs->emplace_back(input);
  767. }
  768. MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexes.size();
  769. }
  770. // set clean workspace address
  771. if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
  772. auto clean_workspaces_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
  773. for (const auto &index : clean_workspaces_indexes) {
  774. auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
  775. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  776. MS_EXCEPTION_IF_NULL(workspace);
  777. workspace->addr = device_address->ptr_;
  778. MS_EXCEPTION_IF_NULL(workspace->addr);
  779. workspace->size = device_address->size_;
  780. kernel_inputs->emplace_back(workspace);
  781. }
  782. }
  783. }
  784. bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
  785. auto &kernels = graph.execution_order();
  786. std::vector<DynamicKernelPtr> dynamic_kernel_list;
  787. auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
  788. if (iter != graph_dynamic_kernel_map_.end()) {
  789. dynamic_kernel_list = iter->second;
  790. }
  791. if (!dynamic_kernel_list.empty() && dynamic_kernel_list.size() != kernels.size()) {
  792. MS_LOG(EXCEPTION) << "The size of dynamic kernels " << dynamic_kernel_list.size()
  793. << " should be equal to the size of kernels " << kernels.size();
  794. }
  795. for (size_t i = 0; i < kernels.size(); ++i) {
  796. auto &kernel = kernels[i];
  797. if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr &&
  798. dynamic_kernel_list[i]->is_dynamic_shape() && AnfAlgo::GetKernelType(kernel) == AICPU_KERNEL) {
  799. dynamic_kernel_list[i]->InferShape();
  800. dynamic_kernel_list[i]->UpdateArgs();
  801. dynamic_kernel_list[i]->Execute();
  802. if (!SyncStream()) {
  803. MS_LOG(ERROR) << "SyncStream failed";
  804. return false;
  805. }
  806. dynamic_kernel_list[i]->PostExecute();
  807. } else {
  808. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  809. MS_EXCEPTION_IF_NULL(kernel_mod);
  810. AddressPtrList kernel_inputs;
  811. AddressPtrList kernel_workspaces;
  812. AddressPtrList kernel_outputs;
  813. GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  814. auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  815. if (!ret) {
  816. MS_LOG(ERROR) << "Launch kernel failed.";
  817. return false;
  818. }
  819. KernelLaunchProfiling(kernels[i]->fullname_with_scope());
  820. }
  821. }
  822. return true;
  823. }
  824. bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
  825. MS_EXCEPTION_IF_NULL(graph);
  826. if (!LaunchKernelMod(*graph)) {
  827. MS_LOG(ERROR) << "LaunchKernelMod failed!";
  828. return false;
  829. }
  830. auto ms_context = MsContext::GetInstance();
  831. MS_EXCEPTION_IF_NULL(ms_context);
  832. if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  833. if (!SyncStream()) {
  834. MS_LOG(ERROR) << "SyncStream failed";
  835. return false;
  836. }
  837. }
  838. return true;
  839. }
  840. void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &,
  841. const std::unordered_set<ValueNodePtr> &, const std::vector<CNodePtr> &) {
  842. MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
  843. }
  844. void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
  845. const std::unordered_set<ValueNodePtr> &value_nodes,
  846. const std::vector<CNodePtr> &execution_order) {
  847. // clear input parameter output address.
  848. for (const auto &input_node : inputs) {
  849. MS_EXCEPTION_IF_NULL(input_node);
  850. if (!input_node->isa<Parameter>()) {
  851. continue;
  852. }
  853. auto parameter = input_node->cast<ParameterPtr>();
  854. MS_EXCEPTION_IF_NULL(parameter);
  855. parameter->DecreaseUsedGraphCount();
  856. // Only the parameter has no graph used, then clear the output address.
  857. if (parameter->used_graph_count() != 0) {
  858. continue;
  859. }
  860. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) {
  861. if (!AnfAlgo::OutputAddrExist(input_node, index)) {
  862. continue;
  863. }
  864. AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
  865. }
  866. }
  867. // clear input value node output address.
  868. for (const auto &value_node : value_nodes) {
  869. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  870. continue;
  871. }
  872. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  873. }
  874. // clear cnode output address.
  875. for (const auto &cnode : execution_order) {
  876. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
  877. if (!AnfAlgo::OutputAddrExist(cnode, index)) {
  878. continue;
  879. }
  880. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  881. }
  882. }
  883. }
  884. bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr,
  885. const AddressPtrList &kernel_inputs,
  886. const AddressPtrList &kernel_outputs,
  887. const AddressPtrList &kernel_workspaces) const {
  888. MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
  889. auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  890. if (!ret) {
  891. MS_LOG(ERROR) << "Launch kernel failed.";
  892. return false;
  893. }
  894. return true;
  895. }
  896. DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const std::string &format, TypeId type) {
  897. auto device_address = CreateDeviceAddress(nullptr, size, format, type);
  898. MS_EXCEPTION_IF_NULL(device_address);
  899. MS_EXCEPTION_IF_NULL(mem_manager_);
  900. auto base_ptr = mem_manager_->MallocMem(kStaticMem, size, device_address);
  901. MS_EXCEPTION_IF_NULL(base_ptr);
  902. return device_address;
  903. }
  904. } // namespace device
  905. } // namespace mindspore