| @@ -93,10 +93,20 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector< | |||
| runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs, tensor_to_node); | |||
| } | |||
| void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| return; | |||
| } | |||
| runtime_.SyncValueNodeDeviceAddr(kernel_graph.get()); | |||
| } | |||
| void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) { | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| SyncValueNodeDeviceAddr(kernel_graph); | |||
| MS_LOG(INFO) << "Bind input output address"; | |||
| runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); | |||
| @@ -130,6 +140,65 @@ void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor: | |||
| MS_LOG(INFO) << "Run graph end"; | |||
| } | |||
| void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int64_t> &tensors_mask) { | |||
| // Check if the graph cache exists. | |||
| if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { | |||
| return; | |||
| } | |||
| // Prepare the graph | |||
| auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| SetKernelInfo(kernel_graph.get()); | |||
| BuildKernel(kernel_graph.get()); | |||
| run_op_graphs_[graph_info] = kernel_graph; | |||
| } | |||
| void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors) { | |||
| for (size_t i = 0; i < base_ref.size(); ++i) { | |||
| if (utils::isa<VectorRef>(base_ref[i])) { | |||
| auto ref_iter = utils::cast<VectorRef>(base_ref[i]); | |||
| SetOutputFlags(ref_iter, outputs_tensors); | |||
| } else if (utils::isa<tensor::TensorPtr>(base_ref[i])) { | |||
| auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref[i]); | |||
| tensor_ptr->SetNeedWait(false); | |||
| outputs_tensors->push_back(tensor_ptr); | |||
| } | |||
| } | |||
| } | |||
| void CPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask) { | |||
| MS_EXCEPTION_IF_NULL(input_tensors); | |||
| BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask); | |||
| EraseValueNodeTensor(tensors_mask, input_tensors); | |||
| auto kernel_graph = run_op_graphs_[graph_info]; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| runtime_.AssignKernelAddress(kernel_graph.get()); | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node; | |||
| runtime_.CreateOutputTensors(kernel_graph.get(), *input_tensors, outputs, &tensor_to_node); | |||
| runtime_.BindInputOutput(kernel_graph.get(), *input_tensors, outputs); | |||
| MS_LOG(INFO) << "Run Op start"; | |||
| auto execution_order = kernel_graph->execution_order(); | |||
| Reorder(&execution_order); | |||
| kernel_graph->set_execution_order(execution_order); | |||
| bool ret = runtime_.Run(kernel_graph.get(), false); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "Run Op failed"; | |||
| } | |||
| std::vector<tensor::TensorPtr> output_tensors; | |||
| SetOutputFlags(*outputs, &output_tensors); | |||
| MS_LOG(INFO) << "Run Op end"; | |||
| } | |||
| void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto &kernel_nodes = kernel_graph->execution_order(); | |||
| @@ -38,10 +38,18 @@ class CPUSession : public SessionBasic { | |||
| void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override; | |||
| void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph); | |||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int64_t> &tensors_mask) override; | |||
| void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask) override; | |||
| private: | |||
| void SetKernelInfo(const KernelGraph *kernel_graph); | |||
| void BuildKernel(const KernelGraph *kernel_graph); | |||
| void SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors); | |||
| void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph); | |||
| device::cpu::CPUKernelRuntime runtime_; | |||
| }; | |||
| MS_REG_SESSION(kCPUDevice, CPUSession); | |||
| @@ -994,6 +994,8 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex | |||
| }); | |||
| return; | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| const auto &tensor_id_list = op_index_with_tensor_id_[op_index]; | |||
| for (size_t i = 0; i < tensor_id_list.size(); ++i) { | |||
| auto tensor_id = tensor_id_list[i]; | |||
| @@ -1003,7 +1005,20 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex | |||
| std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) { | |||
| tensor->set_shape(new_tensor->shape()); | |||
| tensor->set_data_type(new_tensor->data_type()); | |||
| tensor->set_device_address(new_tensor->device_address()); | |||
| if (target != kCPUDevice) { | |||
| tensor->set_device_address(new_tensor->device_address()); | |||
| } else { | |||
| auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | |||
| auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address()); | |||
| auto old_ptr = old_device_address->GetMutablePtr(); | |||
| auto new_ptr = new_device_address->GetPtr(); | |||
| MS_EXCEPTION_IF_NULL(old_ptr); | |||
| MS_EXCEPTION_IF_NULL(new_ptr); | |||
| auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize()); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret; | |||
| } | |||
| } | |||
| }); | |||
| } | |||
| } | |||
| @@ -1264,12 +1279,9 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati | |||
| MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms"; | |||
| auto ms_context = MsContext::GetInstance(); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true); | |||
| std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (device_target != kAscendDevice && device_target != kGPUDevice) { | |||
| MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; | |||
| } | |||
| if (session == nullptr) { | |||
| std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| session = session::SessionFactory::Get().Create(device_target); | |||
| MS_EXCEPTION_IF_NULL(session); | |||
| session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)); | |||
| @@ -56,6 +56,11 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph | |||
| } | |||
| auto tensor = node_value->cast<TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| if (tensor->device_address() != nullptr) { | |||
| AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), 0, | |||
| item_node.get()); | |||
| continue; | |||
| } | |||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0); | |||
| if (output_type_id == kTypeUnknown) { | |||
| output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0); | |||