@@ -93,10 +93,20 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
   runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs, tensor_to_node);
 }
 
+void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
+    return;
+  }
+  runtime_.SyncValueNodeDeviceAddr(kernel_graph.get());
+}
+
 void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                               VectorRef *outputs) {
   auto kernel_graph = GetGraph(graph_id);
   MS_EXCEPTION_IF_NULL(kernel_graph);
+  SyncValueNodeDeviceAddr(kernel_graph);
   MS_LOG(INFO) << "Bind input output address";
   runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
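
The new SyncValueNodeDeviceAddr helper is a PyNative-only step: graph mode returns early, so only single-op execution pays for refreshing value-node device addresses before the graph runs. A minimal standalone sketch of this early-return guard, with illustrative stand-ins for MsContext and the runtime (not MindSpore's actual API):

```cpp
#include <iostream>

// Illustrative stand-in for MsContext's execution-mode parameter.
enum class ExecutionMode { kGraphMode, kPynativeMode };

struct Context {
  ExecutionMode mode = ExecutionMode::kGraphMode;
};

void SyncValueNodeDeviceAddr(const Context &ctx) {
  // Value-node buffers can only change between single-op (PyNative) launches,
  // so graph mode skips the sync entirely.
  if (ctx.mode != ExecutionMode::kPynativeMode) {
    return;
  }
  std::cout << "syncing value node device addresses\n";
}

int main() {
  SyncValueNodeDeviceAddr(Context{ExecutionMode::kGraphMode});     // no-op
  SyncValueNodeDeviceAddr(Context{ExecutionMode::kPynativeMode});  // syncs
}
```

Keeping the mode check inside the helper rather than at the call site leaves RunGraphImpl itself free of mode-specific branching.
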
@@ -130,6 +140,65 @@ void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
   MS_LOG(INFO) << "Run graph end";
 }
 
+void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                             const std::vector<tensor::TensorPtr> &input_tensors,
+                             const std::vector<int64_t> &tensors_mask) {
+  // Check if the graph cache exists.
+  if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
+    return;
+  }
+  // Prepare the graph
+  auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  SetKernelInfo(kernel_graph.get());
+  BuildKernel(kernel_graph.get());
+  run_op_graphs_[graph_info] = kernel_graph;
+}
+
+void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors) {
+  for (size_t i = 0; i < base_ref.size(); ++i) {
+    if (utils::isa<VectorRef>(base_ref[i])) {
+      auto ref_iter = utils::cast<VectorRef>(base_ref[i]);
+      SetOutputFlags(ref_iter, outputs_tensors);
+    } else if (utils::isa<tensor::TensorPtr>(base_ref[i])) {
+      auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref[i]);
+      tensor_ptr->SetNeedWait(false);
+      outputs_tensors->push_back(tensor_ptr);
+    }
+  }
+}
+
+void CPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                           std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                           const std::vector<int64_t> &tensors_mask) {
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
+  EraseValueNodeTensor(tensors_mask, input_tensors);
+  auto kernel_graph = run_op_graphs_[graph_info];
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  runtime_.AssignKernelAddress(kernel_graph.get());
+  std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
+  runtime_.CreateOutputTensors(kernel_graph.get(), *input_tensors, outputs, &tensor_to_node);
+  runtime_.BindInputOutput(kernel_graph.get(), *input_tensors, outputs);
+  MS_LOG(INFO) << "Run Op start";
+  auto execution_order = kernel_graph->execution_order();
+  Reorder(&execution_order);
+  kernel_graph->set_execution_order(execution_order);
+  bool ret = runtime_.Run(kernel_graph.get(), false);
+  if (!ret) {
+    MS_LOG(EXCEPTION) << "Run Op failed";
+  }
+  std::vector<tensor::TensorPtr> output_tensors;
+  SetOutputFlags(*outputs, &output_tensors);
+  MS_LOG(INFO) << "Run Op end";
+}
+
 void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto &kernel_nodes = kernel_graph->execution_order();
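
Three pieces make up the single-op path: BuildOpImpl builds each unique graph once and caches it in run_op_graphs_, RunOpImpl assigns kernel addresses and launches the cached graph, and SetOutputFlags walks the possibly nested VectorRef outputs, clears each tensor's need-wait flag, and collects the flat tensor list. A compilable sketch of that recursive walk, modelling the nesting with std::variant instead of MindSpore's BaseRef hierarchy (all types here are illustrative):

```cpp
#include <iostream>
#include <memory>
#include <variant>
#include <vector>

struct Tensor {
  bool need_wait = true;
};
using TensorPtr = std::shared_ptr<Tensor>;

// A Ref holds either a tensor or a nested list of Refs, mirroring how a
// VectorRef element may itself be a VectorRef.
struct Ref;
using Element = std::variant<TensorPtr, std::vector<Ref>>;
struct Ref {
  Element value;
};

void SetOutputFlags(const std::vector<Ref> &refs, std::vector<TensorPtr> *outputs_tensors) {
  for (const auto &ref : refs) {
    if (const auto *nested = std::get_if<std::vector<Ref>>(&ref.value)) {
      SetOutputFlags(*nested, outputs_tensors);  // recurse into nested outputs
    } else {
      auto tensor = std::get<TensorPtr>(ref.value);
      tensor->need_wait = false;  // mark ready so consumers stop waiting
      outputs_tensors->push_back(tensor);
    }
  }
}

int main() {
  std::vector<Ref> outputs{Ref{std::make_shared<Tensor>()},
                           Ref{std::vector<Ref>{Ref{std::make_shared<Tensor>()}}}};
  std::vector<TensorPtr> flat;
  SetOutputFlags(outputs, &flat);
  std::cout << flat.size() << " tensors marked ready\n";  // prints: 2 tensors marked ready
}
```

Since runtime_.Run is called synchronously here, clearing the wait flag right after the launch is what lets the Python side read the results immediately.
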
@@ -38,10 +38,18 @@ class CPUSession : public SessionBasic {
   void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
   ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override;
   void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
+  void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                   const std::vector<tensor::TensorPtr> &input_tensors,
+                   const std::vector<int64_t> &tensors_mask) override;
+  void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                 std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                 const std::vector<int64_t> &tensors_mask) override;
 
  private:
   void SetKernelInfo(const KernelGraph *kernel_graph);
   void BuildKernel(const KernelGraph *kernel_graph);
+  void SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors);
+  void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph);
   device::cpu::CPUKernelRuntime runtime_;
 };
 MS_REG_SESSION(kCPUDevice, CPUSession);
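
In the header, BuildOpImpl and RunOpImpl override SessionBasic's virtual hooks, and MS_REG_SESSION keeps CPUSession reachable through session::SessionFactory under kCPUDevice, which is what the PyNative executor relies on below. A self-contained sketch of what such a registration macro typically expands to (simplified stand-ins, not MindSpore's actual factory):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Session {
  virtual ~Session() = default;
};

class SessionFactory {
 public:
  static SessionFactory &Get() {
    static SessionFactory instance;
    return instance;
  }
  void Register(const std::string &device, std::function<std::shared_ptr<Session>()> creator) {
    creators_[device] = std::move(creator);
  }
  std::shared_ptr<Session> Create(const std::string &device) {
    auto iter = creators_.find(device);
    return iter == creators_.end() ? nullptr : iter->second();
  }

 private:
  std::map<std::string, std::function<std::shared_ptr<Session>()>> creators_;
};

struct CPUSession : Session {};

// A registration macro usually expands to a static object whose constructor
// runs at start-up and wires the class into the factory.
static struct CPUSessionRegister {
  CPUSessionRegister() {
    SessionFactory::Get().Register("CPU", [] { return std::make_shared<CPUSession>(); });
  }
} g_cpu_session_register;

int main() {
  auto session = SessionFactory::Get().Create("CPU");
  std::cout << std::boolalpha << (session != nullptr) << "\n";  // true
}
```
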
@@ -994,6 +994,8 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
     });
     return;
   }
+  auto ms_context = MsContext::GetInstance();
+  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
   const auto &tensor_id_list = op_index_with_tensor_id_[op_index];
   for (size_t i = 0; i < tensor_id_list.size(); ++i) {
     auto tensor_id = tensor_id_list[i];
@@ -1003,7 +1005,20 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
     std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
       tensor->set_shape(new_tensor->shape());
       tensor->set_data_type(new_tensor->data_type());
-      tensor->set_device_address(new_tensor->device_address());
+      if (target != kCPUDevice) {
+        tensor->set_device_address(new_tensor->device_address());
+      } else {
+        auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
+        auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
+        auto old_ptr = old_device_address->GetMutablePtr();
+        auto new_ptr = new_device_address->GetPtr();
+        MS_EXCEPTION_IF_NULL(old_ptr);
+        MS_EXCEPTION_IF_NULL(new_ptr);
+        auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
+        if (ret != EOK) {
+          MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret;
+        }
+      }
     });
   }
 }
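
These two PyNative hunks work together: the device target is read once, outside the per-tensor loop, and each cached tensor is then either repointed at the new device address (non-CPU targets) or, on CPU, updated in place by copying the new data into its existing host buffer, so anything aliasing that buffer observes the new values. A standalone sketch of the in-place branch; MemcpyS below emulates the bounded-copy semantics of memcpy_s (the real code uses the securec function):

```cpp
#include <cstddef>
#include <cstring>
#include <iostream>
#include <vector>

constexpr int kEok = 0;         // mirrors EOK
constexpr int kRangeError = 1;  // any non-zero value signals failure

// Bounded copy with memcpy_s-like behavior: refuse to write past dest_max.
int MemcpyS(void *dest, std::size_t dest_max, const void *src, std::size_t count) {
  if (dest == nullptr || src == nullptr || count > dest_max) {
    return kRangeError;
  }
  std::memcpy(dest, src, count);
  return kEok;
}

int main() {
  std::vector<float> old_buffer{0.0F, 0.0F};  // the tensor's existing CPU buffer
  std::vector<float> new_buffer{1.0F, 2.0F};  // data produced by the latest launch
  const float *alias = old_buffer.data();     // another tensor sharing the buffer

  if (MemcpyS(old_buffer.data(), old_buffer.size() * sizeof(float),
              new_buffer.data(), new_buffer.size() * sizeof(float)) != kEok) {
    std::cerr << "Memory copy failed\n";
    return 1;
  }
  std::cout << alias[0] << " " << alias[1] << "\n";  // prints: 1 2
}
```

Repointing the tensor instead would detach it from the buffer its aliases still hold, which is presumably why the CPU branch copies rather than swaps.
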
@@ -1264,12 +1279,9 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
   MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
   auto ms_context = MsContext::GetInstance();
   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
-  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  if (device_target != kAscendDevice && device_target != kGPUDevice) {
-    MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
-  }
   if (session == nullptr) {
+    std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
     session = session::SessionFactory::Get().Create(device_target);
     MS_EXCEPTION_IF_NULL(session);
     session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
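
Dropping the Ascend/GPU allowlist is what actually opens the PyNative path to CPU; device_target is still needed to create the session, so its declaration simply moves inside the one-time creation branch. A sketch of that lazy, device-agnostic creation with simplified stand-ins (the real code goes through SessionFactory):

```cpp
#include <iostream>
#include <memory>
#include <string>

struct Session {
  explicit Session(std::string target) : target_(std::move(target)) {}
  std::string target_;
};

std::shared_ptr<Session> GetOrCreateSession(const std::string &device_target) {
  static std::shared_ptr<Session> session;  // created on first use, then reused
  if (session == nullptr) {
    session = std::make_shared<Session>(device_target);
  }
  return session;
}

int main() {
  std::cout << GetOrCreateSession("CPU")->target_ << "\n";  // CPU is accepted now
}
```
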
@@ -56,6 +56,11 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
     }
     auto tensor = node_value->cast<TensorPtr>();
     MS_EXCEPTION_IF_NULL(tensor);
+    if (tensor->device_address() != nullptr) {
+      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), 0,
+                             item_node.get());
+      continue;
+    }
     TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0);
     if (output_type_id == kTypeUnknown) {
       output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);
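
Finally, AssignValueNodeAddress now reuses a device address that an earlier single-op launch already attached to the value node's tensor, instead of allocating and populating a fresh buffer each time; only address-less tensors fall through to the existing allocation path. A compilable sketch of that reuse-if-present pattern with illustrative types:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

struct DeviceAddress {
  std::vector<char> buffer;
};

struct Tensor {
  std::shared_ptr<DeviceAddress> device_address;
};

std::shared_ptr<DeviceAddress> AssignValueNodeAddress(Tensor *tensor, std::size_t size) {
  if (tensor->device_address != nullptr) {
    return tensor->device_address;  // already materialized by a previous launch
  }
  tensor->device_address = std::make_shared<DeviceAddress>();
  tensor->device_address->buffer.resize(size);  // fall through: allocate fresh
  return tensor->device_address;
}

int main() {
  Tensor tensor;
  auto first = AssignValueNodeAddress(&tensor, 16);
  auto second = AssignValueNodeAddress(&tensor, 16);
  std::cout << std::boolalpha << (first == second) << "\n";  // true: buffer reused
}
```

Together with SyncValueNodeDeviceAddr above, this keeps value-node buffers stable across consecutive PyNative launches instead of reallocating them on every op.
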