| @@ -161,8 +161,12 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz | |||||
| } | } | ||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); | tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); | ||||
| MS_EXCEPTION_IF_NULL(tensor); | MS_EXCEPTION_IF_NULL(tensor); | ||||
| address->ptr_ = tensor->data_c(true); | |||||
| address->ref_count_ = INIT_NODE_REF; | |||||
| if (address->ref_count_ > 0 && address->ptr_ != nullptr) { | |||||
| tensor->set_device_address(address); | |||||
| } else { | |||||
| address->ptr_ = tensor->data_c(true); | |||||
| address->ref_count_ = INIT_NODE_REF; | |||||
| } | |||||
| tensor->set_dirty(false); | tensor->set_dirty(false); | ||||
| return tensor; | return tensor; | ||||
| } else if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) { | } else if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) { | ||||
| @@ -211,6 +215,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, | |||||
| } | } | ||||
| tensor->set_dirty(true); | tensor->set_dirty(true); | ||||
| } | } | ||||
| address->ref_count_ = INIT_NODE_REF; | address->ref_count_ = INIT_NODE_REF; | ||||
| tensor->set_device_address(address); | tensor->set_device_address(address); | ||||
| } | } | ||||
| @@ -220,7 +225,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, | |||||
| // new output and bind ptr | // new output and bind ptr | ||||
| auto output_nodes = kernel_graph->outputs(); | auto output_nodes = kernel_graph->outputs(); | ||||
| for (const auto &item : output_nodes) { | for (const auto &item : output_nodes) { | ||||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0); | |||||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); | |||||
| auto out = CreatTensorForOutput(item_with_index.first, item_with_index.second, input_map); | auto out = CreatTensorForOutput(item_with_index.first, item_with_index.second, input_map); | ||||
| outputs->push_back(std::move(out)); | outputs->push_back(std::move(out)); | ||||
| } | } | ||||