Browse Source

!27051 Remove unnecessary members from the CPU kernel runtime

Merge pull request !27051 from kisnwang/clean-code
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
b9b64615f2
2 changed files with 32 additions and 25 deletions
  1. +24
    -19
      mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
  2. +8
    -6
      mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h

+ 24
- 19
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc View File

@@ -183,12 +183,11 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
}

tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
tensor::TensorPtr CPUKernelRuntime::CreateTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node,
size_t index, std::set<DeviceAddressPtr> *bound_addresses) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(tensor_to_node);
MS_EXCEPTION_IF_NULL(bound_addresses);
size_t output_size = AnfAlgo::GetOutputTensorNum(node);
if (index >= output_size) {
MS_LOG(EXCEPTION) << "For node " << node->DebugString() << ", index " << index << " exceed output size "
@@ -223,7 +222,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
}
tensor->set_device_address(address);
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
if (bound_addresses_.find(address) == bound_addresses_.end()) {
if (bound_addresses->find(address) == bound_addresses->end()) {
if (infer_type_id != device_type_id) {
size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
ShapeVector data_shape = tensor->shape();
@@ -234,21 +233,23 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
} else {
tensor->set_sync_status(kNoNeedSync);
}
(void)bound_addresses_.insert(address);
(void)bound_addresses->insert(address);
}
session::KernelWithIndex node_index(node, index);
tensor->SetNeedWait(true);
tensor->SetIsGraphOutput();
(*tensor_to_node)[tensor] = node_index;
return tensor;
}

BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph,
const session::KernelWithIndex &kernel_with_index,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
BaseRef CPUKernelRuntime::GetOrCreateTensorForOutput(
session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
std::map<AnfNodePtr, tensor::TensorPtr> *input_param_tensor_map, std::set<DeviceAddressPtr> *bound_addresses) {
MS_EXCEPTION_IF_NULL(tensor_to_node);
MS_EXCEPTION_IF_NULL(input_param_tensor_map);
auto &input_node = kernel_with_index.first;
auto index = kernel_with_index.second;
MS_EXCEPTION_IF_NULL(input_node);

if (input_node->isa<CNode>()) {
auto node = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
@@ -256,15 +257,18 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_grap
VectorRef ret;
for (size_t i = 1; i < node->inputs().size(); i++) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0);
auto out = CreatTensorForOutput(kernel_graph, item_with_index, tensor_to_node);
auto out = GetOrCreateTensorForOutput(kernel_graph, item_with_index, tensor_to_node, input_param_tensor_map,
bound_addresses);
ret.push_back(out);
}
return ret;
}
return CreatTensorForOutput(kernel_graph, node, index, tensor_to_node);
auto tensor = CreateTensorForOutput(kernel_graph, node, index, bound_addresses);
(*tensor_to_node)[tensor] = kernel_with_index;
return tensor;
} else if (input_node->isa<Parameter>()) {
auto iter = input_param_tensor_map_.find(input_node);
if (iter != input_param_tensor_map_.end()) {
auto iter = input_param_tensor_map->find(input_node);
if (iter != input_param_tensor_map->end()) {
return iter->second;
}
} else if (input_node->isa<ValueNode>()) {
@@ -286,21 +290,22 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
MS_LOG(EXCEPTION) << "Input size " << inputs.size() << " is not equal to input node size " << input_nodes.size();
}

std::map<AnfNodePtr, tensor::TensorPtr> input_param_tensor_map;
size_t input_idx = 0;
for (auto &item : input_nodes) {
MS_EXCEPTION_IF_NULL(item);
input_param_tensor_map_[item] = inputs[input_idx];
input_param_tensor_map[item] = inputs[input_idx];
input_idx++;
}

bound_addresses_.clear();
std::set<DeviceAddressPtr> bound_addresses;
auto output_nodes = kernel_graph->outputs();
for (const auto &item : output_nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, false);
auto out = CreatTensorForOutput(kernel_graph, item_with_index, tensor_to_node);
auto out = GetOrCreateTensorForOutput(kernel_graph, item_with_index, tensor_to_node, &input_param_tensor_map,
&bound_addresses);
outputs->push_back(std::move(out));
}
input_param_tensor_map_.clear();
}

void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &kernel_graph,


+ 8
- 6
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h View File

@@ -57,18 +57,20 @@ class CPUKernelRuntime : public KernelRuntime {
const KernelWithIndex &node_index) const override;

private:
tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node);
BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node);
tensor::TensorPtr CreateTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
std::set<DeviceAddressPtr> *bound_addresses);
BaseRef GetOrCreateTensorForOutput(session::KernelGraph *kernel_graph,
const session::KernelWithIndex &kernel_with_index,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
std::map<AnfNodePtr, tensor::TensorPtr> *input_param_tensor_map,
std::set<DeviceAddressPtr> *bound_addresses);
void BindInputTensorAddressPtr(const session::KernelGraph &graph, const std::vector<tensor::TensorPtr> &inputs);
void BindOutputTensorAddressPtr(const VectorRef *outputs);
void AssignValueNodeAddress(session::KernelGraph *kernel_graph);
void AssignInputNodeAddress(const session::KernelGraph *kernel_graph);
void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph);
void AddRuntimeAddress(DeviceAddress *address, std::vector<kernel::AddressPtr> *input_list);
std::set<DeviceAddressPtr> bound_addresses_;
std::map<AnfNodePtr, tensor::TensorPtr> input_param_tensor_map_;

bool initialized_{false};
};
} // namespace cpu


Loading…
Cancel
Save