You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cpu_kernel_runtime.cc 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/cpu/cpu_kernel_runtime.h"
  17. #include <string>
  18. #include <vector>
  19. #include <memory>
  20. #include <numeric>
  21. #include <utility>
  22. #include <functional>
  23. #include <unordered_map>
  24. #include <set>
  25. #include "kernel/kernel.h"
  26. #include "device/cpu/cpu_device_address.h"
  27. #include "utils/context/ms_context.h"
  28. #include "utils/config_manager.h"
  29. #include "common/utils.h"
  30. #include "session/anf_runtime_algorithm.h"
  31. #include "session/session_basic.h"
  32. #include "operator/ops.h"
  33. namespace mindspore {
  34. namespace device {
  35. namespace cpu {
  36. const size_t INIT_NODE_REF = 1;
  37. namespace {
  38. TypeId GetCPUSupportOutputTypeId(const TypeId type_id) {
  39. TypeId support_type_id = type_id;
  40. if (type_id == kNumberTypeUInt32) {
  41. support_type_id = kNumberTypeInt32;
  42. }
  43. if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 ||
  44. type_id == kNumberTypeFloat64) {
  45. support_type_id = kNumberTypeFloat32;
  46. }
  47. if (support_type_id != kNumberTypeInt32 && support_type_id != kNumberTypeFloat32) {
  48. MS_LOG(EXCEPTION) << "Check output type failed.";
  49. }
  50. return support_type_id;
  51. }
  52. } // namespace
  53. void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) {
  54. AssignValueNodeAddress(kernel_graph);
  55. AssignInputNodeAddress(kernel_graph);
  56. AssignKernelOutputAddress(kernel_graph);
  57. resource_manager_.MemPlan(kernel_graph);
  58. resource_manager_.MemMalloc(kernel_graph);
  59. }
  60. void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) {
  61. MS_EXCEPTION_IF_NULL(kernel_graph);
  62. size_t type_size = sizeof(float);
  63. for (auto &item_node : kernel_graph->graph_value_nodes()) {
  64. MS_EXCEPTION_IF_NULL(item_node);
  65. if (item_node->isa<ValueNode>()) {
  66. auto value_node = item_node->cast<ValueNodePtr>();
  67. MS_EXCEPTION_IF_NULL(value_node);
  68. auto node_value = value_node->value();
  69. MS_EXCEPTION_IF_NULL(node_value);
  70. if (!node_value->isa<tensor::Tensor>()) {
  71. continue;
  72. }
  73. auto tensor = node_value->cast<TensorPtr>();
  74. MS_EXCEPTION_IF_NULL(tensor);
  75. std::vector<int> data_shape = tensor->shape();
  76. size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
  77. DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32);
  78. if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {
  79. address->ptr_ = tensor->data_c(false);
  80. } else {
  81. address->ptr_ = resource_manager_.MemMalloc(tensor_size);
  82. if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
  83. tensor->data_c(false))) {
  84. MS_LOG(EXCEPTION) << "Value node sync host to device failed!";
  85. }
  86. }
  87. address->ref_count_ = INIT_NODE_REF;
  88. AnfAlgo::SetOutputAddr(address, 0, item_node.get());
  89. }
  90. }
  91. }
  92. void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) {
  93. MS_EXCEPTION_IF_NULL(kernel_graph);
  94. size_t type_size = sizeof(float);
  95. for (auto &item : kernel_graph->inputs()) {
  96. MS_EXCEPTION_IF_NULL(item);
  97. if (item->isa<Parameter>()) {
  98. auto output_num = AnfAlgo::GetOutputTensorNum(item);
  99. for (size_t index = 0; index < output_num; index++) {
  100. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  101. std::vector<size_t> fmt_shape = AnfAlgo::GetOutputDeviceShape(item, index);
  102. size_t tensor_size =
  103. fmt_shape.empty() ? type_size
  104. : std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies<size_t>());
  105. auto format = AnfAlgo::GetOutputFormat(item, index);
  106. auto address = CreateDeviceAddress(nullptr, tensor_size, format, output_type_id);
  107. AnfAlgo::SetOutputAddr(address, index, item.get());
  108. }
  109. }
  110. }
  111. }
  112. void CPUKernelRuntime::AssignKernelOutputAddress(const session::KernelGraph *kernel_graph) {
  113. MS_EXCEPTION_IF_NULL(kernel_graph);
  114. auto kernels = kernel_graph->execution_order();
  115. for (auto &kernel : kernels) {
  116. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  117. MS_EXCEPTION_IF_NULL(kernel_mod);
  118. auto output_sizes = kernel_mod->GetOutputSizeList();
  119. for (size_t i = 0; i < output_sizes.size(); ++i) {
  120. auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
  121. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  122. AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i,
  123. kernel.get());
  124. }
  125. auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
  126. for (size_t i = 0; i < workspace_sizes.size(); ++i) {
  127. AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32),
  128. i, kernel.get());
  129. }
  130. }
  131. }
  132. DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
  133. TypeId type_id) {
  134. return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
  135. }
  136. BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index,
  137. const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map,
  138. std::set<DeviceAddressPtr> *bound_addresses,
  139. std::vector<tensor::TensorPtr> *need_sync_outputs) {
  140. auto &input_node = kernel_with_index.first;
  141. auto index = kernel_with_index.second;
  142. MS_EXCEPTION_IF_NULL(input_node);
  143. if (input_node->isa<CNode>()) {
  144. auto node = input_node->cast<CNodePtr>();
  145. MS_EXCEPTION_IF_NULL(node);
  146. if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimMakeTuple->name()) {
  147. VectorRef ret;
  148. for (size_t i = 1; i < node->inputs().size(); i++) {
  149. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0);
  150. auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs);
  151. ret.push_back(out);
  152. }
  153. return ret;
  154. }
  155. size_t output_size = AnfAlgo::GetOutputTensorNum(node);
  156. if (index >= output_size) {
  157. MS_LOG(EXCEPTION) << "Invalid input index " << index;
  158. }
  159. auto address = AnfAlgo::GetMutableOutputAddr(node, index);
  160. MS_EXCEPTION_IF_NULL(address);
  161. auto shape = AnfAlgo::GetOutputInferShape(node, index);
  162. std::vector<int> temp_shape;
  163. (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
  164. TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
  165. type_id = GetCPUSupportOutputTypeId(type_id);
  166. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  167. MS_EXCEPTION_IF_NULL(tensor);
  168. if (bound_addresses->find(address) != bound_addresses->end()) {
  169. tensor->set_device_address(address);
  170. need_sync_outputs->emplace_back(tensor);
  171. } else {
  172. address->ptr_ = tensor->data_c(true);
  173. address->ref_count_ = INIT_NODE_REF;
  174. (void)bound_addresses->insert(address);
  175. }
  176. tensor->set_dirty(false);
  177. return tensor;
  178. } else if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) {
  179. auto iter = input_map.find(input_node.get());
  180. if (iter != input_map.end()) {
  181. return iter->second;
  182. }
  183. }
  184. return BaseRef();
  185. }
  186. void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
  187. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs,
  188. std::vector<tensor::TensorPtr> *need_sync_outputs) {
  189. MS_EXCEPTION_IF_NULL(kernel_graph);
  190. MS_EXCEPTION_IF_NULL(outputs);
  191. // bind input ptr
  192. auto &input_nodes = kernel_graph->inputs();
  193. if (input_nodes.size() != inputs.size()) {
  194. MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
  195. }
  196. std::unordered_map<AnfNode *, tensor::TensorPtr> input_map;
  197. size_t input_idx = 0;
  198. for (auto &item : input_nodes) {
  199. MS_EXCEPTION_IF_NULL(item);
  200. input_map[item.get()] = inputs[input_idx];
  201. if (item->isa<Parameter>()) {
  202. auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
  203. auto tensor = inputs[input_idx];
  204. auto tensor_address = tensor->device_address();
  205. MS_EXCEPTION_IF_NULL(address);
  206. MS_EXCEPTION_IF_NULL(tensor);
  207. if (tensor_address != nullptr && tensor_address != address) {
  208. (void)tensor->data_sync();
  209. }
  210. std::vector<int> data_shape = tensor->shape();
  211. size_t tensor_size =
  212. std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies<size_t>());
  213. if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {
  214. address->ptr_ = tensor->data_c(false);
  215. } else {
  216. address->ptr_ = resource_manager_.MemMalloc(tensor_size);
  217. if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
  218. tensor->data_c(false))) {
  219. MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!";
  220. }
  221. tensor->set_dirty(true);
  222. }
  223. address->ref_count_ = INIT_NODE_REF;
  224. tensor->set_device_address(address);
  225. }
  226. input_idx++;
  227. }
  228. // new output and bind ptr
  229. std::set<DeviceAddressPtr> bound_addresses;
  230. auto output_nodes = kernel_graph->outputs();
  231. for (const auto &item : output_nodes) {
  232. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
  233. auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs);
  234. outputs->push_back(std::move(out));
  235. }
  236. }
  237. void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector<kernel::AddressPtr> *input_list) {
  238. MS_EXCEPTION_IF_NULL(address);
  239. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  240. MS_EXCEPTION_IF_NULL(input);
  241. if (address->ptr_ == nullptr) {
  242. address->ptr_ = resource_manager_.MemMalloc(address->size_);
  243. }
  244. MS_EXCEPTION_IF_NULL(address->ptr_);
  245. input->addr = address->ptr_;
  246. input->size = address->size_;
  247. input_list->push_back(input);
  248. }
  249. void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) {
  250. resource_manager_.IncreaseSummaryRefCount(summary_outputs);
  251. }
  252. void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) {
  253. resource_manager_.DecreaseSummaryRefCount(summary_outputs);
  254. }
  255. bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) {
  256. MS_EXCEPTION_IF_NULL(kernel_graph);
  257. resource_manager_.IncreaseAddressRefCount(kernel_graph);
  258. auto kernels = kernel_graph->execution_order();
  259. for (const auto &kernel : kernels) {
  260. std::vector<kernel::AddressPtr> kernel_inputs;
  261. std::vector<kernel::AddressPtr> kernel_workspaces;
  262. std::vector<kernel::AddressPtr> kernel_outputs;
  263. size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
  264. for (size_t i = 0; i < input_num; ++i) {
  265. auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i).get();
  266. MS_EXCEPTION_IF_NULL(device_address);
  267. AddRuntimeAddress(device_address, &kernel_inputs);
  268. }
  269. size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
  270. for (size_t i = 0; i < output_num; ++i) {
  271. auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i).get();
  272. MS_EXCEPTION_IF_NULL(device_address);
  273. AddRuntimeAddress(device_address, &kernel_outputs);
  274. }
  275. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  276. MS_EXCEPTION_IF_NULL(kernel_mod);
  277. for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
  278. auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
  279. MS_EXCEPTION_IF_NULL(device_address);
  280. AddRuntimeAddress(device_address, &kernel_workspaces);
  281. }
  282. auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, 0);
  283. resource_manager_.DecreaseAddressRefCount(kernel);
  284. if (!ret) {
  285. MS_LOG(EXCEPTION) << "Launch kernel failed.";
  286. }
  287. }
  288. return true;
  289. }
  290. } // namespace cpu
  291. } // namespace device
  292. } // namespace mindspore