| @@ -202,7 +202,9 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr | |||
| std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, 0, output_format, output_type); | |||
| stub_output_tensor->set_device_address(device_address); | |||
| output_tensor_info.output_stub_tensor = stub_output_tensor; | |||
| output_tensor_info.is_weight = !dynamic_cast<device::KernelInfo *>(output_node->kernel_info())->is_feature_map(); | |||
| auto kernel_info = dynamic_cast<const device::KernelInfo *>(output_node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| output_tensor_info.is_weight = !(kernel_info->is_feature_map()); | |||
| (*op_output_info)[kernel_with_index] = output_tensor_info; | |||
| } | |||
| } | |||
| @@ -113,20 +113,10 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector< | |||
| runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs, tensor_to_node); | |||
| } | |||
| void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| return; | |||
| } | |||
| runtime_.SyncValueNodeDeviceAddr(kernel_graph.get()); | |||
| } | |||
| void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) { | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| SyncValueNodeDeviceAddr(kernel_graph); | |||
| MS_LOG(INFO) << "Bind input output address"; | |||
| runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); | |||
| @@ -50,7 +50,6 @@ class CPUSession : public SessionBasic { | |||
| void SetKernelInfo(const KernelGraph *kernel_graph); | |||
| void BuildKernel(const KernelGraph *kernel_graph); | |||
| void SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors); | |||
| void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph); | |||
| device::cpu::CPUKernelRuntime runtime_; | |||
| }; | |||
| MS_REG_SESSION(kCPUDevice, CPUSession); | |||
| @@ -425,8 +425,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor: | |||
| VectorRef *outputs) { | |||
| auto &kernel_graph = graphs_[graph_id]; | |||
| MS_LOG(INFO) << "RunGraph graph_id: " << graph_id; | |||
| // In pynative mode, device addresses of tensors in value nodes change. | |||
| SyncValueNodeDeviceAddr(kernel_graph); | |||
| // Load input data from user input | |||
| LoadInputData(kernel_graph, inputs); | |||
| if (debugger_) { | |||
| @@ -449,8 +447,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor: | |||
| #endif | |||
| Execute(kernel_graph); | |||
| } | |||
| // In pynative mode, device addresses of tensors in value nodes need be clean. | |||
| CleanValueNodeDeviceAddr(kernel_graph); | |||
| // Summary | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -540,28 +536,6 @@ void GPUSession::PostIterationDbg(const std::shared_ptr<KernelGraph> &kernel_gra | |||
| } | |||
| } | |||
| void GPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| return; | |||
| } | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| runtime_instance->SyncValueNodeDeviceAddr(kernel_graph.get()); | |||
| } | |||
| void GPUSession::CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| return; | |||
| } | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| runtime_instance->CleanValueNodeDeviceAddr(kernel_graph.get()); | |||
| } | |||
| void GPUSession::SyncStream() { | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| @@ -82,10 +82,6 @@ class GPUSession : public SessionBasic { | |||
| void PostIterationDbg(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| GraphId CompileGraphImpl(KernelGraphPtr kernel_graph); | |||
| }; | |||
| using GPUSessionPtr = std::shared_ptr<GPUSession>; | |||
| @@ -404,7 +404,7 @@ bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) { | |||
| return true; | |||
| } | |||
| void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs, | |||
| void GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs, | |||
| std::map<AnfNodePtr, size_t> *parameter_index) { | |||
| size_t index = 0; | |||
| for (const auto &input_node : graph->inputs()) { | |||
| @@ -512,7 +512,7 @@ void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vect | |||
| } | |||
| } | |||
| void GetRefCount(KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) { | |||
| void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| for (const auto &kernel : graph->execution_order()) { | |||
| for (size_t i = 1; i < kernel->inputs().size(); i += 1) { | |||
| @@ -712,52 +712,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { | |||
| MS_LOG(INFO) << "AssignStaticMemoryValueNode end"; | |||
| } | |||
| void KernelRuntime::SyncValueNodeDeviceAddr(session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "SyncValueNodeDeviceAddr start"; | |||
| for (auto &value_node : graph->graph_value_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto &node_value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(node_value); | |||
| if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) { | |||
| continue; | |||
| } | |||
| std::vector<tensor::TensorPtr> tensors; | |||
| TensorValueToTensor(node_value, &tensors); | |||
| for (size_t index = 0; index < tensors.size(); index += 1) { | |||
| const auto &tensor = tensors[index]; | |||
| if (tensor->device_address() != nullptr) { | |||
| AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), index, | |||
| value_node.get()); | |||
| } else { | |||
| MS_LOG(INFO) << "Tensor of ValueNode[" << value_node->fullname_with_scope() << "]'s device address is nullptr."; | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "SyncValueNodeDeviceAddr end"; | |||
| } | |||
| void KernelRuntime::CleanValueNodeDeviceAddr(session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "CleanValueNodeDeviceAddr start"; | |||
| for (auto &value_node : graph->graph_value_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto &node_value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(node_value); | |||
| if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) { | |||
| continue; | |||
| } | |||
| std::vector<tensor::TensorPtr> tensors; | |||
| TensorValueToTensor(node_value, &tensors); | |||
| for (size_t index = 0; index < tensors.size(); index += 1) { | |||
| if (tensors[index]->device_address() != nullptr) { | |||
| AnfAlgo::SetOutputAddr(nullptr, index, value_node.get()); | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "CleanValueNodeDeviceAddr end"; | |||
| } | |||
| void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(mem_manager_); | |||
| @@ -67,8 +67,6 @@ class KernelRuntime { | |||
| const AddressPtrList &kernel_workspaces) const; | |||
| virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); | |||
| virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); | |||
| virtual void SyncValueNodeDeviceAddr(session::KernelGraph *graph); | |||
| virtual void CleanValueNodeDeviceAddr(session::KernelGraph *graph); | |||
| virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, | |||
| const std::unordered_set<ValueNodePtr> &value_nodes, | |||
| const std::vector<CNodePtr> &execution_order); | |||