@@ -21,6 +21,7 @@
 #include <utility>
 #include <functional>
 #include <unordered_map>
+#include <set>
 #include "kernel/kernel.h"
 #include "device/cpu/cpu_device_address.h"
 #include "utils/context/ms_context.h"
@@ -139,8 +140,12 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
   return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
 }
 
-BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, size_t index,
-                                               const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map) {
+BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index,
+                                               const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map,
+                                               std::set<DeviceAddressPtr> *bound_addresses,
+                                               std::vector<tensor::TensorPtr> *need_sync_outputs) {
+  auto &input_node = kernel_with_index.first;
+  auto index = kernel_with_index.second;
   MS_EXCEPTION_IF_NULL(input_node);
   if (input_node->isa<CNode>() && AnfAlgo::GetCNodeName(input_node) == prim::kPrimMakeTuple->name()) {
     auto cnode = input_node->cast<CNodePtr>();
@@ -148,7 +153,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz
     VectorRef ret;
     for (size_t i = 1; i < cnode->inputs().size(); i++) {
       auto item_with_index = AnfAlgo::VisitKernelWithReturnType(cnode->input(i), 0);
-      auto out = CreatTensorForOutput(item_with_index.first, item_with_index.second, input_map);
+      auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs);
       ret.push_back(out);
     }
     return ret;
@@ -169,11 +174,13 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz
     type_id = GetCPUSupportOutputTypeId(type_id);
     tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
     MS_EXCEPTION_IF_NULL(tensor);
-    if (address->ref_count_ > 0 && address->ptr_ != nullptr) {
+    if (bound_addresses->find(address) != bound_addresses->end()) {
       tensor->set_device_address(address);
+      need_sync_outputs->emplace_back(tensor);
     } else {
       address->ptr_ = tensor->data_c(true);
       address->ref_count_ = INIT_NODE_REF;
+      (void)bound_addresses->insert(address);
     }
     tensor->set_dirty(false);
     return tensor;
@@ -187,7 +194,8 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz
 }
 
 void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
-                                       const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
+                                       const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs,
+                                       std::vector<tensor::TensorPtr> *need_sync_outputs) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   MS_EXCEPTION_IF_NULL(outputs);
   // bind input ptr
@@ -195,10 +203,8 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
   if (input_nodes.size() != inputs.size()) {
     MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
   }
-
   std::unordered_map<AnfNode *, tensor::TensorPtr> input_map;
   size_t input_idx = 0;
-  size_t type_size = sizeof(float);
   for (auto &item : input_nodes) {
     MS_EXCEPTION_IF_NULL(item);
     input_map[item.get()] = inputs[input_idx];
@@ -212,7 +218,8 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
         (void)tensor->data_sync();
       }
       std::vector<int> data_shape = tensor->shape();
-      size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
+      size_t tensor_size =
+        std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies<size_t>());
       if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {
         address->ptr_ = tensor->data_c(false);
       } else {
@@ -223,18 +230,17 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
         }
         tensor->set_dirty(true);
       }
-
       address->ref_count_ = INIT_NODE_REF;
       tensor->set_device_address(address);
     }
     input_idx++;
   }
-
   // new output and bind ptr
+  std::set<DeviceAddressPtr> bound_addresses;
   auto output_nodes = kernel_graph->outputs();
   for (const auto &item : output_nodes) {
     auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
-    auto out = CreatTensorForOutput(item_with_index.first, item_with_index.second, input_map);
+    auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs);
     outputs->push_back(std::move(out));
   }
 }
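Note (not part of the patch): the core change above is that CreatTensorForOutput no longer infers aliasing from address->ref_count_; the caller threads a std::set of already-bound device addresses through the recursion. The first output tensor to meet an address takes ownership of its buffer, while any later tensor that resolves to the same address only keeps a reference and is queued in need_sync_outputs so its data can be copied out after the kernels run. A minimal standalone sketch of that bookkeeping, using invented stand-in types rather than MindSpore's DeviceAddressPtr/TensorPtr, might look like:

#include <cstddef>
#include <iostream>
#include <memory>
#include <set>
#include <vector>

// Hypothetical stand-ins for the runtime's device address and host tensor.
struct DeviceAddress { void *ptr = nullptr; };
using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;

struct Tensor {
  std::vector<float> host;          // host-side buffer
  DeviceAddressPtr device_address;  // set when the tensor only references an existing buffer
};
using TensorPtr = std::shared_ptr<Tensor>;

// Bind one output: the first tensor to see an address owns its memory;
// later tensors just reference it and are queued for a post-run sync.
TensorPtr BindOutput(const DeviceAddressPtr &address, size_t elem_count,
                     std::set<DeviceAddressPtr> *bound_addresses,
                     std::vector<TensorPtr> *need_sync_outputs) {
  auto tensor = std::make_shared<Tensor>();
  tensor->host.resize(elem_count);
  if (bound_addresses->count(address) > 0) {
    // Address already backs another output: remember the tensor, sync it later.
    tensor->device_address = address;
    need_sync_outputs->push_back(tensor);
  } else {
    // First user: point the device address at this tensor's host buffer.
    address->ptr = tensor->host.data();
    bound_addresses->insert(address);
  }
  return tensor;
}

int main() {
  auto addr = std::make_shared<DeviceAddress>();
  std::set<DeviceAddressPtr> bound;
  std::vector<TensorPtr> need_sync;

  auto out1 = BindOutput(addr, 4, &bound, &need_sync);  // owns the buffer
  auto out2 = BindOutput(addr, 4, &bound, &need_sync);  // aliases out1, queued for sync

  std::cout << "owner buffer bound: " << (addr->ptr == out1->host.data()) << "\n";
  std::cout << "tensors needing post-run sync: " << need_sync.size() << "\n";
  return 0;
}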