Merge pull request !3872 from kisnwang/cache-internal-tensor
tags/v0.7.0-beta
@@ -961,18 +961,40 @@ void KernelGraph::PrintGraphExecuteOrder() const {
  }
}
void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) {
void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, int output_idx,
                                    bool unique_target) {
  if (front_node == nullptr || node == nullptr) {
    MS_LOG(INFO) << "Front node or node is nullptr";
    return;
  }
  MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
  front_to_internal_outputs_map_[front_node] = node;
  int output_idx = 0;
  if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
    output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
  }
  internal_outputs_to_front_map_[node][output_idx] = front_node;
  internal_outputs_to_front_map_[node][output_idx] = std::pair<AnfNodePtr, bool>(front_node, unique_target);
}
void KernelGraph::AddInternalOutputTensor(const AnfNodePtr &node, int output_idx, const tensor::TensorPtr &tensor) {
  if (node == nullptr) {
    return;
  }
  internal_outputs_tensor_map_[node][output_idx] = tensor;
}
tensor::TensorPtr KernelGraph::GetInternalOutputTensor(const AnfNodePtr &node, int output_idx) {
  if (node == nullptr) {
    return nullptr;
  }
  auto iter = internal_outputs_tensor_map_.find(node);
  if (iter == internal_outputs_tensor_map_.end()) {
    return nullptr;
  }
  auto idx_iter = iter->second.find(output_idx);
  if (idx_iter == iter->second.end()) {
    return nullptr;
  }
  return idx_iter->second;
}
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx,
@@ -996,7 +1018,7 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
  if (src_output_idx == -1) {
    internal_outputs_to_front_map_[new_node] = front_nodes;
    for (const auto &front_node_iter : front_nodes) {
      front_to_internal_outputs_map_[front_node_iter.second] = new_node;
      front_to_internal_outputs_map_[front_node_iter.second.first] = new_node;
    }
    internal_outputs_to_front_map_.erase(iter);
    return;
@@ -1008,9 +1030,9 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
    MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node";
    return;
  }
  auto front_node = front_node_iter->second;
  internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node;
  front_to_internal_outputs_map_[front_node] = new_node;
  auto front_node_pair = front_node_iter->second;
  internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node_pair;
  front_to_internal_outputs_map_[front_node_pair.first] = new_node;
  front_nodes.erase(index);
  if (front_nodes.empty()) {
    internal_outputs_to_front_map_.erase(iter);
@@ -1027,16 +1049,30 @@ AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_nod
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, int output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter != internal_outputs_to_front_map_.end()) {
    if (output_idx == -1) {
      return true;
    }
    auto &front_nodes = front_nodes_iter->second;
    if (front_nodes.find(output_idx) != front_nodes.end()) {
      return true;
    }
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  return false;
  if (output_idx == -1) {
    return true;
  }
  auto &front_nodes = front_nodes_iter->second;
  if (front_nodes.find(output_idx) == front_nodes.end()) {
    return false;
  }
  return true;
}
bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  auto &front_nodes = front_nodes_iter->second;
  auto idx_iter = front_nodes.find(output_idx);
  if (idx_iter == front_nodes.end()) {
    return false;
  }
  return idx_iter->second.second;
}
void KernelGraph::UpdateChildGraphOrder() {
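For readers skimming the diff, the bookkeeping above boils down to two per-graph maps keyed by backend node and output index: one records which front-end node an internal output corresponds to plus whether that output has a unique device target, and the other caches the host tensor created for it so later runs can reuse the same object. Below is a minimal standalone sketch of that scheme; `Node`, `Tensor`, and the class name are placeholders, not the MindSpore types.

```cpp
#include <memory>
#include <unordered_map>
#include <utility>

// Stand-in types; the real code uses AnfNodePtr and tensor::TensorPtr.
struct Node {};
struct Tensor {};
using NodePtr = std::shared_ptr<Node>;
using TensorPtr = std::shared_ptr<Tensor>;

class InternalOutputBook {
 public:
  // Record that (node, output_idx) is the backend internal output for front_node.
  void Add(const NodePtr &front_node, const NodePtr &node, int output_idx, bool unique_target) {
    to_front_[node][output_idx] = {front_node, unique_target};
  }
  // Cache the host tensor created for (node, output_idx).
  void CacheTensor(const NodePtr &node, int output_idx, const TensorPtr &tensor) {
    tensors_[node][output_idx] = tensor;
  }
  // Return the cached tensor, or nullptr if none has been created yet.
  TensorPtr CachedTensor(const NodePtr &node, int output_idx) const {
    auto it = tensors_.find(node);
    if (it == tensors_.end()) return nullptr;
    auto jt = it->second.find(output_idx);
    return jt == it->second.end() ? nullptr : jt->second;
  }
  // True only if the output is internal and all of its users share one device target.
  bool IsUniqueTarget(const NodePtr &node, int output_idx) const {
    auto it = to_front_.find(node);
    if (it == to_front_.end()) return false;
    auto jt = it->second.find(output_idx);
    return jt != it->second.end() && jt->second.second;
  }

 private:
  std::unordered_map<NodePtr, std::unordered_map<int, std::pair<NodePtr, bool>>> to_front_;
  std::unordered_map<NodePtr, std::unordered_map<int, TensorPtr>> tensors_;
};
```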
@@ -143,11 +143,16 @@ class KernelGraph : public FuncGraph {
  void PrintGraphExecuteOrder() const;
  const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
  void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
  void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
  void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, int output_idx = 0,
                         bool unique_target = false);
  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
                             int dst_output_idx = -1);
  AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
  bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const;
  bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const;
  void AddInternalOutputTensor(const AnfNodePtr &node, int output_idx, const tensor::TensorPtr &tensor);
  tensor::TensorPtr GetInternalOutputTensor(const AnfNodePtr &node, int output_idx);
  uint32_t current_epoch() const { return current_epoch_; }
  void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
  void UpdateChildGraphOrder();
@@ -217,7 +222,8 @@ class KernelGraph : public FuncGraph {
  CNodePtr end_goto_;
  bool null_output_;
  std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
  std::unordered_map<AnfNodePtr, std::unordered_map<int, AnfNodePtr>> internal_outputs_to_front_map_;
  std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
  std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
  uint32_t current_epoch_;
};
}  // namespace session
@@ -58,51 +58,38 @@ ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
  return parameter->default_param();
}
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
                        const std::vector<tensor::TensorPtr> &input_tensors) {
tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph,
                                     const DeviceAddressPtr &address) {
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  // if node is a value node, no need to sync addr from device to host
  if (!AnfAlgo::OutputAddrExist(node, output_index)) {
    if (node->isa<ValueNode>()) {
      auto value_node = node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      return value_node->value();
    }
    if (node->isa<Parameter>()) {
      for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
        if (input_idx >= input_tensors.size()) {
          MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
        }
        if (graph.inputs()[input_idx] == node) {
          return input_tensors[input_idx];
        }
      }
      MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
    }
  }
  // if the process reaches here, item_with_index is a real node (Parameter or executable CNode)
  auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  MS_EXCEPTION_IF_NULL(address);
  auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  TypeId type_id = kNumberTypeFloat32;
  type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  MS_EXCEPTION_IF_NULL(graph);
  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (type_id == kTypeUnknown) {
    type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  }
  tensor::TensorPtr tensor;
  std::vector<int> temp_shape;
  if (graph.IsInternalOutput(node, output_index)) {
  if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
    temp_shape.emplace_back(1);
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    tensor->set_device_address(address);
    tensor->set_dirty(false);
    return tensor;
  }
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  tensor = graph->GetInternalOutputTensor(node, output_index);
  if (tensor == nullptr) {
    auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    bool is_internal_output = graph->IsInternalOutput(node, output_index);
    if (is_internal_output) {
      graph->AddInternalOutputTensor(node, output_index, tensor);
    }
  }
  // in pynative mode, data is only copied to host when the user wants to print it
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  MS_EXCEPTION_IF_NULL(address);
  if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
    tensor->set_device_address(address);
    tensor->set_dirty(false);
@@ -114,7 +101,35 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
  return tensor;
}
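The cache-or-create step above (GetInternalOutputTensor, then AddInternalOutputTensor on a miss) is the core of this change: an internal output's host tensor is created once and then handed back on every subsequent run instead of being re-allocated. A rough sketch of that decision flow follows, assuming simplified stand-in types and a per-output map in place of the real KernelGraph calls; none of these names are the MindSpore API.

```cpp
#include <map>
#include <memory>
#include <vector>

// Stand-ins, not the MindSpore types.
struct DeviceAddress {};
using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;
struct Tensor {
  int type_id = 0;
  std::vector<int> shape;
  DeviceAddressPtr device_address;
  bool dirty = true;
};
using TensorPtr = std::shared_ptr<Tensor>;

// Per-node cache keyed by output index, standing in for internal_outputs_tensor_map_.
using OutputTensorCache = std::map<int, TensorPtr>;

// Mirrors the branch structure of CreateOutputTensor in this diff:
//  * a unique-target internal output gets a 1-element placeholder tensor that only
//    carries the device address (the consumer reads through the address);
//  * other outputs reuse the cached tensor if one exists, otherwise a new tensor is
//    created from the inferred shape and cached when the output is internal.
TensorPtr GetOrCreateOutputTensor(int output_idx, int type_id, const std::vector<int> &inferred_shape,
                                  const DeviceAddressPtr &address, bool unique_target_internal,
                                  bool internal, OutputTensorCache *cache) {
  if (unique_target_internal) {
    auto tensor = std::make_shared<Tensor>();
    tensor->type_id = type_id;
    tensor->shape = {1};
    tensor->device_address = address;
    tensor->dirty = false;
    return tensor;
  }
  auto iter = cache->find(output_idx);
  TensorPtr tensor = (iter != cache->end()) ? iter->second : nullptr;
  if (tensor == nullptr) {
    tensor = std::make_shared<Tensor>();
    tensor->type_id = type_id;
    tensor->shape = inferred_shape;
    if (internal) {
      (*cache)[output_idx] = tensor;  // next run returns this same object
    }
  }
  return tensor;
}
```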
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph,
                        const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  // if node is a value node, no need to sync addr from device to host
  if (!AnfAlgo::OutputAddrExist(node, output_index)) {
    if (node->isa<ValueNode>()) {
      auto value_node = node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      return value_node->value();
    }
    if (node->isa<Parameter>()) {
      for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
        if (input_idx >= input_tensors.size()) {
          MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
        }
        if (graph->inputs()[input_idx] == node) {
          return input_tensors[input_idx];
        }
      }
      MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
    }
  }
  auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  return CreateOutputTensor(node, output_index, graph, address);
}
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                              const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
@@ -308,7 +323,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node, ref_real_node_index)) {
  if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
    auto kernel_info = ref_real_node->kernel_info();
    if (kernel_info == nullptr || !kernel_info->has_build_info()) {
      MS_LOG(INFO) << "No kernel info";
@@ -888,7 +903,7 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
    outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors));
    outputs->emplace_back(CreateTensorForOutput(item, kernel_graph, input_tensors));
  }
}
@@ -967,6 +982,71 @@ void SessionBasic::Summary(KernelGraph *graph) {
  summary_callback_(0, params_list);
}
namespace {
bool CNodePrimIsValueNode(const AnfNodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    return false;
  }
  auto prim = cnode->input(kAnfPrimitiveIndex);
  if (prim == nullptr || !prim->isa<ValueNode>()) {
    return false;
  }
  return true;
}
void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
                          const FuncGraphManagerPtr &front_func_graph_manager,
                          const std::shared_ptr<KernelGraph> &backend_graph) {
  auto node_users = front_func_graph_manager->node_users();
  auto users = node_users[front_node];
  auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0);
  auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0);
  auto front_real_kernel = front_real_kernel_pair.first;
  std::string kernel_target = GetCNodeTarget(front_real_kernel);
  bool internal_output = CNodePrimIsValueNode(front_real_kernel);
  bool unique_target = true;
  if (internal_output && opt::IsNopNode(front_real_kernel)) {
    auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
    auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
    if (pre_node_target != kernel_target) {
      unique_target = false;
    }
  }
  if (internal_output) {
    for (auto user : users) {
      auto cnode = user.first->cast<CNodePtr>();
      if (cnode == nullptr) {
        internal_output = false;
        break;
      }
      auto prim = cnode->input(kAnfPrimitiveIndex);
      if (prim == nullptr || !prim->isa<ValueNode>()) {
        internal_output = false;
        break;
      }
      if (!AnfAlgo::IsRealKernel(user.first)) {
        internal_output = false;
        break;
      }
      if (kernel_target != GetCNodeTarget(user.first)) {
        unique_target = false;
      }
    }
  }
  if (internal_output) {
    MS_LOG(INFO) << "Internal output: " << front_node->DebugString() << "To "
                 << backend_real_kernel_pair.first->DebugString();
    backend_graph->AddInternalOutput(front_node, backend_real_kernel_pair.first, backend_real_kernel_pair.second,
                                     unique_target);
  }
}
}  // namespace
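HandleInternalOutput factors the old inline checks out of ConstructOutput (removed further down) and adds the unique_target refinement: a front output is treated as internal only when every user is a real-kernel CNode whose primitive is a ValueNode, and it is unique-target only when all those users run on the same device target as the producing kernel. Here is a standalone sketch of just that classification, with plain strings standing in for the GetCNodeTarget/node_users machinery (the names are illustrative, not the real API):

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical per-user summary; the real code derives this from the front graph's
// node_users() together with AnfAlgo::IsRealKernel and GetCNodeTarget.
struct UserInfo {
  bool is_real_kernel_cnode;  // real-kernel CNode whose primitive is a ValueNode
  std::string target;         // device target of the user ("CPU", "Ascend", ...)
};

// Returns {internal_output, unique_target}.
std::pair<bool, bool> ClassifyOutput(const std::string &kernel_target, const std::vector<UserInfo> &users) {
  bool internal_output = true;
  bool unique_target = true;
  for (const auto &user : users) {
    if (!user.is_real_kernel_cnode) {
      internal_output = false;  // any non-kernel user disqualifies the output entirely
      break;
    }
    if (user.target != kernel_target) {
      unique_target = false;    // still internal, but its tensor cannot be bound to one device
    }
  }
  return {internal_output, unique_target};
}
```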
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> output_args;
@@ -982,9 +1062,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
    if (context_ptr->execution_mode() == kPynativeMode) {
      return backend_anf;
    }
    auto front_real_kernel_pair = AnfAlgo::VisitKernel(out, 0);
    auto front_real_kernel = front_real_kernel_pair.first;
    auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_anf, 0);
    MS_EXCEPTION_IF_NULL(out);
    auto out_func_graph = out->func_graph();
    MS_EXCEPTION_IF_NULL(out_func_graph);
@@ -992,51 +1070,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
    if (out_func_graph_manager == nullptr) {
      return backend_anf;
    }
    auto node_users = out_func_graph_manager->node_users();
    auto users = node_users[out];
    bool internal_output = true;
    std::string kernel_target = GetCNodeTarget(front_real_kernel);
    if (front_real_kernel != nullptr && front_real_kernel->isa<CNode>()) {
      auto front_cnode = front_real_kernel->cast<CNodePtr>();
      if (front_cnode != nullptr) {
        auto prim = front_cnode->input(kAnfPrimitiveIndex);
        if (prim == nullptr || !prim->isa<ValueNode>()) {
          internal_output = false;
        }
      } else {
        internal_output = false;
      }
    }
    if (internal_output && opt::IsNopNode(front_real_kernel)) {
      auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
      auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
      if (pre_node_target != kernel_target) {
        internal_output = false;
      }
    }
    if (internal_output) {
      for (auto user : users) {
        auto cnode = user.first->cast<CNodePtr>();
        if (cnode == nullptr) {
          internal_output = false;
          break;
        }
        auto prim = cnode->input(kAnfPrimitiveIndex);
        if (prim == nullptr || !prim->isa<ValueNode>()) {
          internal_output = false;
          break;
        }
        if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
          internal_output = false;
          break;
        }
      }
    }
    if (internal_output) {
      MS_LOG(INFO) << "Internal output: " << out->DebugString() << "To "
                   << backend_real_kernel_pair.first->DebugString();
      graph->AddInternalOutput(out, backend_real_kernel_pair.first);
    }
    HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph);
    return backend_anf;
  }
  MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
@@ -20,7 +20,7 @@
#include <numeric>
#include <utility>
#include <functional>
#include <unordered_map>
#include <map>
#include <set>
#include "backend/kernel_compiler/kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
@@ -124,11 +124,10 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
}
tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index,
                                                         std::set<DeviceAddressPtr> *bound_addresses,
tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node,
                                                         size_t index,
                                                         std::vector<tensor::TensorPtr> *need_sync_outputs) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(bound_addresses);
  MS_EXCEPTION_IF_NULL(need_sync_outputs);
  size_t output_size = AnfAlgo::GetOutputTensorNum(node);
  if (index >= output_size) {
@@ -136,14 +135,21 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, s
  }
  auto address = AnfAlgo::GetMutableOutputAddr(node, index);
  MS_EXCEPTION_IF_NULL(address);
  auto shape = AnfAlgo::GetOutputInferShape(node, index);
  std::vector<int> temp_shape;
  (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
  TypeId infer_type_id = AnfAlgo::GetOutputInferDataType(node, index);
  TypeId device_type_id = AnfAlgo::GetOutputDeviceDataType(node, index);
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
  MS_EXCEPTION_IF_NULL(tensor);
  if (bound_addresses->find(address) != bound_addresses->end()) {
  tensor::TensorPtr tensor = kernel_graph->GetInternalOutputTensor(node, index);
  if (tensor == nullptr) {
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    std::vector<int> temp_shape;
    (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
    tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
    bool is_internal_output = kernel_graph->IsInternalOutput(node, index);
    if (is_internal_output) {
      kernel_graph->AddInternalOutputTensor(node, index, tensor);
    }
  }
  if (bound_addresses_.find(address) != bound_addresses_.end()) {
    tensor->set_device_address(address);
    need_sync_outputs->emplace_back(tensor);
  } else {
@@ -159,15 +165,14 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, s
      address->ptr_ = tensor->data_c();
    }
    address->ref_count_ = INIT_NODE_REF;
    (void)bound_addresses->insert(address);
    (void)bound_addresses_.insert(address);
  }
  tensor->set_dirty(false);
  return tensor;
}
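The CPU runtime applies the same reuse pattern, with one extra decision: if the output's device address has already been bound to some tensor's host buffer (tracked in the bound-address set), the new tensor can only carry the address and must be synced after the run; otherwise the tensor's own host buffer becomes the kernel's output buffer and no copy is needed. A simplified sketch of that branch, with stand-in types rather than the real DeviceAddress and tensor::Tensor:

```cpp
#include <memory>
#include <set>
#include <vector>

// Stand-ins; not the MindSpore classes.
struct DeviceAddress {
  void *ptr = nullptr;
};
using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;

struct Tensor {
  std::vector<float> data;          // host buffer
  DeviceAddressPtr device_address;  // bound device address, if any
  bool dirty = true;
};
using TensorPtr = std::shared_ptr<Tensor>;

// If the address is already backed by another tensor's buffer, remember to sync this
// tensor after execution; otherwise point the kernel's output address at this tensor's
// buffer directly.
void BindOrSync(const DeviceAddressPtr &address, const TensorPtr &tensor,
                std::set<DeviceAddressPtr> *bound_addresses, std::vector<TensorPtr> *need_sync_outputs) {
  if (bound_addresses->count(address) > 0) {
    tensor->device_address = address;
    need_sync_outputs->push_back(tensor);
  } else {
    address->ptr = tensor->data.data();
    bound_addresses->insert(address);
  }
  tensor->dirty = false;
}
```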
BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index,
                                               const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map,
                                               std::set<DeviceAddressPtr> *bound_addresses,
BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph,
                                               const session::KernelWithIndex &kernel_with_index,
                                               std::vector<tensor::TensorPtr> *need_sync_outputs) {
  auto &input_node = kernel_with_index.first;
  auto index = kernel_with_index.second;
@@ -179,15 +184,15 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k
      VectorRef ret;
      for (size_t i = 1; i < node->inputs().size(); i++) {
        auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0);
        auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs);
        auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs);
        ret.push_back(out);
      }
      return ret;
    }
    return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs);
    return CreatTensorForOutput(kernel_graph, node, index, need_sync_outputs);
  } else if (input_node->isa<Parameter>()) {
    auto iter = input_map.find(input_node.get());
    if (iter != input_map.end()) {
    auto iter = input_param_tensor_map_.find(input_node);
    if (iter != input_param_tensor_map_.end()) {
      return iter->second;
    }
  } else if (input_node->isa<ValueNode>()) {
@@ -197,10 +202,8 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k
  }
  return BaseRef();
}
void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
                                       const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs,
                                       std::vector<tensor::TensorPtr> *need_sync_outputs) {
void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                                       VectorRef *outputs, std::vector<tensor::TensorPtr> *need_sync_outputs) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  // bind input ptr
@@ -208,11 +211,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
  if (input_nodes.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
  }
  std::unordered_map<AnfNode *, tensor::TensorPtr> input_map;
  input_param_tensor_map_.clear();
  size_t input_idx = 0;
  for (auto &item : input_nodes) {
    MS_EXCEPTION_IF_NULL(item);
    input_map[item.get()] = inputs[input_idx];
    input_param_tensor_map_[item] = inputs[input_idx];
    if (item->isa<Parameter>()) {
      auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
      auto tensor = inputs[input_idx];
@@ -222,7 +225,6 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
      if (tensor_address != nullptr && tensor_address != address) {
        (void)tensor->data_sync();
      }
      if (tensor->data_type() == address->type_id_ || tensor->data_type() == kNumberTypeFloat32 ||
          tensor->data_type() == kNumberTypeInt32) {
        address->ptr_ = tensor->data_c();
@@ -243,11 +245,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
    input_idx++;
  }
  // new output and bind ptr
  std::set<DeviceAddressPtr> bound_addresses;
  bound_addresses_.clear();
  auto output_nodes = kernel_graph->outputs();
  for (const auto &item : output_nodes) {
    auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
    auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs);
    auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs);
    outputs->push_back(std::move(out));
  }
}
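A side effect of this refactor is that the per-call scratch state (the input map and the bound-address set) moves from local variables threaded through every helper into CPUKernelRuntime members (input_param_tensor_map_, bound_addresses_) that are cleared at the top of each BindInputOutput, so the CreatTensorForOutput overloads only need the kernel graph passed down. The pattern, in a trivial hedged sketch with illustrative names (nothing here is the real runtime API):

```cpp
#include <map>
#include <set>
#include <string>

// Per-call scratch state kept as members and reset on every call, so private helpers
// can read it without extra parameters.
class RuntimeSketch {
 public:
  void BindInputOutput(const std::map<std::string, int> &inputs) {
    input_map_.clear();    // reset member state instead of building locals
    bound_keys_.clear();
    for (const auto &kv : inputs) {
      BindOne(kv.first, kv.second);  // helper reads/writes the members directly
    }
  }

 private:
  void BindOne(const std::string &name, int value) {
    input_map_[name] = value;
    bound_keys_.insert(name);
  }
  std::map<std::string, int> input_map_;
  std::set<std::string> bound_keys_;
};
```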
@@ -19,7 +19,7 @@
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include <map>
#include <set>
#include "runtime/device/kernel_runtime.h"
#include "backend/session/kernel_graph.h"
@@ -38,7 +38,7 @@ class CPUKernelRuntime : public KernelRuntime {
  bool Init() override { return true; }
  bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
  void AssignKernelAddress(session::KernelGraph *kernel_graph);
  void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
  void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                       VectorRef *outputs, std::vector<tensor::TensorPtr> *need_sync_outputs);
  void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
  void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
@@ -49,19 +49,18 @@ class CPUKernelRuntime : public KernelRuntime {
                                TypeId type_id) override;
 private:
  tensor::TensorPtr CreatTensorForOutput(const CNodePtr &node, size_t index,
                                         std::set<DeviceAddressPtr> *bound_addresses,
  tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
                                         std::vector<tensor::TensorPtr> *need_sync_outputs);
  BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index,
                               const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map,
                               std::set<DeviceAddressPtr> *bound_addresses,
  BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index,
                               std::vector<tensor::TensorPtr> *need_sync_outputs);
  void AssignValueNodeAddress(session::KernelGraph *kernel_graph);
  void AssignInputNodeAddress(const session::KernelGraph *kernel_graph);
  void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph);
  void AddRuntimeAddress(DeviceAddress *address, std::vector<kernel::AddressPtr> *input_list);
  CPUResourceManager resource_manager_;
  std::set<DeviceAddressPtr> bound_addresses_;
  std::map<AnfNodePtr, tensor::TensorPtr> input_param_tensor_map_;
};
} // namespace cpu
} // namespace device