You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

cpu_kernel_runtime.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/cpu/cpu_kernel_runtime.h"
  17. #include <string>
  18. #include <vector>
  19. #include <memory>
  20. #include <numeric>
  21. #include <utility>
  22. #include <functional>
  23. #include <unordered_map>
  24. #include <set>
  25. #include "kernel/kernel.h"
  26. #include "device/cpu/cpu_device_address.h"
  27. #include "utils/context/ms_context.h"
  28. #include "utils/config_manager.h"
  29. #include "utils/profile.h"
  30. #include "common/utils.h"
  31. #include "session/anf_runtime_algorithm.h"
  32. #include "session/session_basic.h"
  33. #include "operator/ops.h"
  34. namespace mindspore {
  35. namespace device {
  36. namespace cpu {
  37. const size_t INIT_NODE_REF = 1;
  38. namespace {
  39. TypeId GetCPUSupportOutputTypeId(const TypeId type_id) {
  40. TypeId support_type_id = type_id;
  41. if (type_id == kNumberTypeUInt32) {
  42. support_type_id = kNumberTypeInt32;
  43. }
  44. if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 ||
  45. type_id == kNumberTypeFloat64) {
  46. support_type_id = kNumberTypeFloat32;
  47. }
  48. if (support_type_id != kNumberTypeInt32 && support_type_id != kNumberTypeFloat32) {
  49. MS_LOG(EXCEPTION) << "Check output type failed.";
  50. }
  51. return support_type_id;
  52. }
  53. } // namespace
// Assigns device addresses for every node of the graph, then plans and
// allocates the backing host memory.
// NOTE: the call order is significant — MemPlan/MemMalloc operate on the
// addresses created by the three Assign* passes above them.
void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) {
  AssignValueNodeAddress(kernel_graph);    // constants (value nodes)
  AssignInputNodeAddress(kernel_graph);    // graph parameters
  AssignKernelOutputAddress(kernel_graph); // kernel outputs and workspaces
  resource_manager_.MemPlan(kernel_graph);
  resource_manager_.MemMalloc(kernel_graph);
}
  61. void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) {
  62. MS_EXCEPTION_IF_NULL(kernel_graph);
  63. size_t type_size = sizeof(float);
  64. for (auto &item_node : kernel_graph->graph_value_nodes()) {
  65. MS_EXCEPTION_IF_NULL(item_node);
  66. if (item_node->isa<ValueNode>()) {
  67. auto value_node = item_node->cast<ValueNodePtr>();
  68. MS_EXCEPTION_IF_NULL(value_node);
  69. auto node_value = value_node->value();
  70. MS_EXCEPTION_IF_NULL(node_value);
  71. if (!node_value->isa<tensor::Tensor>()) {
  72. continue;
  73. }
  74. auto tensor = node_value->cast<TensorPtr>();
  75. MS_EXCEPTION_IF_NULL(tensor);
  76. std::vector<int> data_shape = tensor->shape();
  77. size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
  78. DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32);
  79. MS_EXCEPTION_IF_NULL(address);
  80. if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {
  81. address->ptr_ = tensor->data_c();
  82. } else {
  83. address->ptr_ = resource_manager_.MemMalloc(tensor_size);
  84. if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
  85. tensor->data_c())) {
  86. MS_LOG(EXCEPTION) << "Value node sync host to device failed!";
  87. }
  88. }
  89. address->ref_count_ = INIT_NODE_REF;
  90. AnfAlgo::SetOutputAddr(address, 0, item_node.get());
  91. }
  92. }
  93. }
  94. void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) {
  95. MS_EXCEPTION_IF_NULL(kernel_graph);
  96. size_t type_size = sizeof(float);
  97. for (auto &item : kernel_graph->inputs()) {
  98. MS_EXCEPTION_IF_NULL(item);
  99. if (item->isa<Parameter>()) {
  100. auto output_num = AnfAlgo::GetOutputTensorNum(item);
  101. for (size_t index = 0; index < output_num; index++) {
  102. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  103. std::vector<size_t> fmt_shape = AnfAlgo::GetOutputDeviceShape(item, index);
  104. size_t tensor_size =
  105. fmt_shape.empty() ? type_size
  106. : std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies<size_t>());
  107. auto format = AnfAlgo::GetOutputFormat(item, index);
  108. auto address = CreateDeviceAddress(nullptr, tensor_size, format, output_type_id);
  109. AnfAlgo::SetOutputAddr(address, index, item.get());
  110. }
  111. }
  112. }
  113. }
  114. void CPUKernelRuntime::AssignKernelOutputAddress(const session::KernelGraph *kernel_graph) {
  115. MS_EXCEPTION_IF_NULL(kernel_graph);
  116. auto kernels = kernel_graph->execution_order();
  117. for (auto &kernel : kernels) {
  118. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  119. MS_EXCEPTION_IF_NULL(kernel_mod);
  120. auto output_sizes = kernel_mod->GetOutputSizeList();
  121. for (size_t i = 0; i < output_sizes.size(); ++i) {
  122. auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
  123. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  124. AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i,
  125. kernel.get());
  126. }
  127. auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
  128. for (size_t i = 0; i < workspace_sizes.size(); ++i) {
  129. AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32),
  130. i, kernel.get());
  131. }
  132. }
  133. }
  134. DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
  135. TypeId type_id) {
  136. return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
  137. }
  138. tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index,
  139. std::set<DeviceAddressPtr> *bound_addresses,
  140. std::vector<tensor::TensorPtr> *need_sync_outputs) {
  141. MS_EXCEPTION_IF_NULL(node);
  142. MS_EXCEPTION_IF_NULL(bound_addresses);
  143. MS_EXCEPTION_IF_NULL(need_sync_outputs);
  144. size_t output_size = AnfAlgo::GetOutputTensorNum(node);
  145. if (index >= output_size) {
  146. MS_LOG(EXCEPTION) << "Invalid input index " << index;
  147. }
  148. auto address = AnfAlgo::GetMutableOutputAddr(node, index);
  149. MS_EXCEPTION_IF_NULL(address);
  150. auto shape = AnfAlgo::GetOutputInferShape(node, index);
  151. std::vector<int> temp_shape;
  152. (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
  153. TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
  154. type_id = GetCPUSupportOutputTypeId(type_id);
  155. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  156. MS_EXCEPTION_IF_NULL(tensor);
  157. if (bound_addresses->find(address) != bound_addresses->end()) {
  158. tensor->set_device_address(address);
  159. need_sync_outputs->emplace_back(tensor);
  160. } else {
  161. address->ptr_ = tensor->data_c();
  162. address->ref_count_ = INIT_NODE_REF;
  163. (void)bound_addresses->insert(address);
  164. }
  165. tensor->set_dirty(false);
  166. return tensor;
  167. }
  168. BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index,
  169. const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map,
  170. std::set<DeviceAddressPtr> *bound_addresses,
  171. std::vector<tensor::TensorPtr> *need_sync_outputs) {
  172. auto &input_node = kernel_with_index.first;
  173. auto index = kernel_with_index.second;
  174. MS_EXCEPTION_IF_NULL(input_node);
  175. if (input_node->isa<CNode>()) {
  176. auto node = input_node->cast<CNodePtr>();
  177. MS_EXCEPTION_IF_NULL(node);
  178. if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimMakeTuple->name()) {
  179. VectorRef ret;
  180. for (size_t i = 1; i < node->inputs().size(); i++) {
  181. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0);
  182. auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs);
  183. ret.push_back(out);
  184. }
  185. return ret;
  186. }
  187. return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs);
  188. } else if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) {
  189. auto iter = input_map.find(input_node.get());
  190. if (iter != input_map.end()) {
  191. return iter->second;
  192. }
  193. }
  194. return BaseRef();
  195. }
  196. void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
  197. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs,
  198. std::vector<tensor::TensorPtr> *need_sync_outputs) {
  199. MS_EXCEPTION_IF_NULL(kernel_graph);
  200. MS_EXCEPTION_IF_NULL(outputs);
  201. // bind input ptr
  202. auto &input_nodes = kernel_graph->inputs();
  203. if (input_nodes.size() != inputs.size()) {
  204. MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
  205. }
  206. std::unordered_map<AnfNode *, tensor::TensorPtr> input_map;
  207. size_t input_idx = 0;
  208. for (auto &item : input_nodes) {
  209. MS_EXCEPTION_IF_NULL(item);
  210. input_map[item.get()] = inputs[input_idx];
  211. if (item->isa<Parameter>()) {
  212. auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
  213. auto tensor = inputs[input_idx];
  214. auto tensor_address = tensor->device_address();
  215. MS_EXCEPTION_IF_NULL(address);
  216. MS_EXCEPTION_IF_NULL(tensor);
  217. if (tensor_address != nullptr && tensor_address != address) {
  218. (void)tensor->data_sync();
  219. }
  220. std::vector<int> data_shape = tensor->shape();
  221. size_t tensor_size =
  222. std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies<size_t>());
  223. if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {
  224. address->ptr_ = tensor->data_c();
  225. } else {
  226. address->ptr_ = resource_manager_.MemMalloc(tensor_size);
  227. if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
  228. tensor->data_c())) {
  229. MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!";
  230. }
  231. tensor->set_dirty(true);
  232. }
  233. address->ref_count_ = INIT_NODE_REF;
  234. tensor->set_device_address(address);
  235. }
  236. input_idx++;
  237. }
  238. // new output and bind ptr
  239. std::set<DeviceAddressPtr> bound_addresses;
  240. auto output_nodes = kernel_graph->outputs();
  241. for (const auto &item : output_nodes) {
  242. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
  243. auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs);
  244. outputs->push_back(std::move(out));
  245. }
  246. }
  247. void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector<kernel::AddressPtr> *input_list) {
  248. MS_EXCEPTION_IF_NULL(address);
  249. MS_EXCEPTION_IF_NULL(input_list);
  250. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  251. MS_EXCEPTION_IF_NULL(input);
  252. if (address->ptr_ == nullptr) {
  253. address->ptr_ = resource_manager_.MemMalloc(address->size_);
  254. }
  255. MS_EXCEPTION_IF_NULL(address->ptr_);
  256. input->addr = address->ptr_;
  257. input->size = address->size_;
  258. input_list->push_back(input);
  259. }
// Delegates to the resource manager: bumps ref counts on the addresses of
// summary outputs so their memory survives until the summaries are read.
void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) {
  resource_manager_.IncreaseSummaryRefCount(summary_outputs);
}
// Delegates to the resource manager: releases the extra ref counts taken by
// IncreaseSummaryRefCount once the summary outputs have been consumed.
void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) {
  resource_manager_.DecreaseSummaryRefCount(summary_outputs);
}
// Executes the graph's kernels in execution order on the CPU.
// For each kernel it gathers input/output/workspace addresses (allocating
// lazily via AddRuntimeAddress), launches the kernel mod, and then lets the
// resource manager release input memory whose ref count drops to zero.
// Throws on launch failure; returns true otherwise.
bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Pre-raise ref counts so shared addresses are not reclaimed mid-run.
  resource_manager_.IncreaseAddressRefCount(kernel_graph);
  auto kernels = kernel_graph->execution_order();
  for (const auto &kernel : kernels) {
#ifdef ENABLE_PROFILE
    double start_time = GetTime();
#endif
    std::vector<kernel::AddressPtr> kernel_inputs;
    std::vector<kernel::AddressPtr> kernel_workspaces;
    std::vector<kernel::AddressPtr> kernel_outputs;
    // Inputs come from the producing nodes' output addresses.
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
    for (size_t i = 0; i < input_num; ++i) {
      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i).get();
      MS_EXCEPTION_IF_NULL(device_address);
      AddRuntimeAddress(device_address, &kernel_inputs);
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
    for (size_t i = 0; i < output_num; ++i) {
      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i).get();
      MS_EXCEPTION_IF_NULL(device_address);
      AddRuntimeAddress(device_address, &kernel_outputs);
    }
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
      auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
      MS_EXCEPTION_IF_NULL(device_address);
      AddRuntimeAddress(device_address, &kernel_workspaces);
    }
    auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, 0);
    // Ref counts are released before the launch result is checked, so input
    // memory is reclaimed even on the failure path (the exception below ends
    // the run anyway).
    resource_manager_.DecreaseAddressRefCount(kernel);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Launch kernel failed.";
    }
#ifdef ENABLE_PROFILE
    double cost_time = GetTime() - start_time;
    MS_LOG(INFO) << "cpu kernel: " << kernel->fullname_with_scope() << " costs " << cost_time * 1e6 << " us";
#endif
  }
  return true;
}
  308. } // namespace cpu
  309. } // namespace device
  310. } // namespace mindspore