| @@ -318,7 +318,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor:: | |||||
| #endif | #endif | ||||
| { | { | ||||
| // run task on device | // run task on device | ||||
| Execute(kernel_graph); | |||||
| Execute(kernel_graph, true); | |||||
| } | } | ||||
| // summary | // summary | ||||
| Summary(kernel_graph.get()); | Summary(kernel_graph.get()); | ||||
| @@ -348,17 +348,6 @@ void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelG | |||||
| MS_LOG(INFO) << "Finish"; | MS_LOG(INFO) << "Finish"; | ||||
| } | } | ||||
| void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||||
| MS_LOG(INFO) << "Start!"; | |||||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | |||||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||||
| bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get()); | |||||
| if (!ret_ok) { | |||||
| MS_LOG(EXCEPTION) << "Run task error!"; | |||||
| } | |||||
| MS_LOG(INFO) << "Finish!"; | |||||
| } | |||||
| bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const { | bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const { | ||||
| return run_op_graphs_.find(graph_info) != run_op_graphs_.end(); | return run_op_graphs_.find(graph_info) != run_op_graphs_.end(); | ||||
| } | } | ||||
| @@ -398,7 +387,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i | |||||
| // load input data to device | // load input data to device | ||||
| LoadInputData(graph, input_tensors); | LoadInputData(graph, input_tensors); | ||||
| // run op | // run op | ||||
| RunOpExecTask(graph); | |||||
| Execute(graph, false); | |||||
| // get output | // get output | ||||
| if (op_run_info.value != nullptr) { | if (op_run_info.value != nullptr) { | ||||
| std::vector<tensor::TensorPtr> pre_output_tensors; | std::vector<tensor::TensorPtr> pre_output_tensors; | ||||
| @@ -552,21 +541,30 @@ void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const { | |||||
| void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const { | void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const { | ||||
| MS_LOG(INFO) << "Start!"; | MS_LOG(INFO) << "Start!"; | ||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||||
| (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph); | (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph); | ||||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | ||||
| MS_EXCEPTION_IF_NULL(runtime_instance); | MS_EXCEPTION_IF_NULL(runtime_instance); | ||||
| bool ret_ok = runtime_instance->Load(kernel_graph.get()); | |||||
| bool ret_ok = runtime_instance->Load(kernel_graph.get(), is_task_sink); | |||||
| if (!ret_ok) { | if (!ret_ok) { | ||||
| MS_LOG(EXCEPTION) << "Load task error!"; | MS_LOG(EXCEPTION) << "Load task error!"; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| } | } | ||||
| void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||||
| void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const { | |||||
| MS_LOG(INFO) << "Start!"; | MS_LOG(INFO) << "Start!"; | ||||
| bool is_task_sink = false; | |||||
| if (is_task) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||||
| } | |||||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | ||||
| MS_EXCEPTION_IF_NULL(runtime_instance); | MS_EXCEPTION_IF_NULL(runtime_instance); | ||||
| bool ret_ok = runtime_instance->Run(kernel_graph.get()); | |||||
| bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink); | |||||
| if (!ret_ok) { | if (!ret_ok) { | ||||
| MS_LOG(EXCEPTION) << "run task error!"; | MS_LOG(EXCEPTION) << "run task error!"; | ||||
| } | } | ||||
| @@ -13,8 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H | #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H | ||||
| #define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H | #define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -82,13 +84,12 @@ class AscendSession : public SessionBasic { | |||||
| KernelGraph *kernel_graph) const; | KernelGraph *kernel_graph) const; | ||||
| void RunOpMemoryClear(const KernelGraph *kernel_graph) const; | void RunOpMemoryClear(const KernelGraph *kernel_graph) const; | ||||
| void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||||
| void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const; | |||||
| void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs); | void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs); | ||||
| void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| // below functions are used for run op | // below functions are used for run op | ||||
| void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const; | void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const; | ||||
| void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||||
| static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs); | static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs); | ||||
| static void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | static void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | ||||
| @@ -118,7 +118,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten | |||||
| debugger_->PreExecute(kernel_graph); | debugger_->PreExecute(kernel_graph); | ||||
| } | } | ||||
| #endif | #endif | ||||
| bool ret = runtime_.Run(kernel_graph.get()); | |||||
| bool ret = runtime_.Run(kernel_graph.get(), false); | |||||
| if (!ret) { | if (!ret) { | ||||
| MS_LOG(EXCEPTION) << "Run graph failed"; | MS_LOG(EXCEPTION) << "Run graph failed"; | ||||
| } | } | ||||
| @@ -191,9 +191,9 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const | |||||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); | auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); | ||||
| MS_EXCEPTION_IF_NULL(runtime_instance); | MS_EXCEPTION_IF_NULL(runtime_instance); | ||||
| #ifdef ENABLE_DEBUGGER | #ifdef ENABLE_DEBUGGER | ||||
| if (!runtime_instance->Run(kernel_graph.get(), debugger_.get())) { | |||||
| if (!runtime_instance->Run(kernel_graph.get(), false, debugger_.get())) { | |||||
| #else | #else | ||||
| if (!runtime_instance->Run(kernel_graph.get())) { | |||||
| if (!runtime_instance->Run(kernel_graph.get(), false)) { | |||||
| #endif | #endif | ||||
| MS_LOG(EXCEPTION) << "GPU execute graph failed!"; | MS_LOG(EXCEPTION) << "GPU execute graph failed!"; | ||||
| } | } | ||||
| @@ -454,10 +454,7 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size | |||||
| return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id); | return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id); | ||||
| } | } | ||||
| bool AscendKernelRuntime::Load(session::KernelGraph *graph) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||||
| bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { | |||||
| if (!is_task_sink) { | if (!is_task_sink) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -609,17 +606,14 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) { | |||||
| } | } | ||||
| } | } | ||||
| bool AscendKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) { | |||||
| bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger) { | |||||
| bool ret = false; | bool ret = false; | ||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| #if defined(_WIN32) || defined(_WIN64) | #if defined(_WIN32) || defined(_WIN64) | ||||
| auto start_time = std::chrono::steady_clock::now(); | auto start_time = std::chrono::steady_clock::now(); | ||||
| #else | #else | ||||
| struct timeval start_time, end_time; | struct timeval start_time, end_time; | ||||
| (void)gettimeofday(&start_time, nullptr); | (void)gettimeofday(&start_time, nullptr); | ||||
| #endif | #endif | ||||
| bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||||
| if (is_task_sink) { | if (is_task_sink) { | ||||
| ret = RunTask(graph); | ret = RunTask(graph); | ||||
| } else { | } else { | ||||
| @@ -44,8 +44,8 @@ class AscendKernelRuntime : public KernelRuntime { | |||||
| bool GenTask(const session::KernelGraph *graph); | bool GenTask(const session::KernelGraph *graph); | ||||
| bool LoadTask(const session::KernelGraph *graph); | bool LoadTask(const session::KernelGraph *graph); | ||||
| bool RunTask(const session::KernelGraph *graph); | bool RunTask(const session::KernelGraph *graph); | ||||
| bool Load(session::KernelGraph *graph) override; | |||||
| bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override; | |||||
| bool Load(session::KernelGraph *graph, bool is_task_sink) override; | |||||
| bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override; | |||||
| void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, | void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, | ||||
| const std::unordered_set<ValueNodePtr> &value_nodes, | const std::unordered_set<ValueNodePtr> &value_nodes, | ||||
| const std::vector<CNodePtr> &execution_order) override; | const std::vector<CNodePtr> &execution_order) override; | ||||
| @@ -287,7 +287,7 @@ void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutput | |||||
| resource_manager_.DecreaseSummaryRefCount(summary_outputs); | resource_manager_.DecreaseSummaryRefCount(summary_outputs); | ||||
| } | } | ||||
| bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, Debugger *debugger) { | |||||
| bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, bool is_task_sink, Debugger *debugger) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| resource_manager_.IncreaseAddressRefCount(kernel_graph); | resource_manager_.IncreaseAddressRefCount(kernel_graph); | ||||
| @@ -36,7 +36,7 @@ class CPUKernelRuntime : public KernelRuntime { | |||||
| ~CPUKernelRuntime() override = default; | ~CPUKernelRuntime() override = default; | ||||
| bool Init() override { return true; } | bool Init() override { return true; } | ||||
| bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override; | |||||
| bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override; | |||||
| void AssignKernelAddress(session::KernelGraph *kernel_graph); | void AssignKernelAddress(session::KernelGraph *kernel_graph); | ||||
| void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | ||||
| VectorRef *outputs); | VectorRef *outputs); | ||||
| @@ -1,112 +1,112 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ | |||||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <utility> | |||||
| #include <unordered_map> | |||||
| #include <unordered_set> | |||||
| #include "runtime/device/kernel_runtime.h" | |||||
| #include "runtime/device/kernel_runtime_manager.h" | |||||
| #include "backend/optimizer/mem_reuse/mem_swap_manager.h" | |||||
| namespace mindspore { | |||||
| namespace device { | |||||
| namespace gpu { | |||||
| using mindspore::device::memswap::MemSwapManagerPtr; | |||||
| class GPUKernelRuntime : public KernelRuntime { | |||||
| public: | |||||
| GPUKernelRuntime() = default; | |||||
| ~GPUKernelRuntime() override = default; | |||||
| bool Init() override; | |||||
| void ReleaseDeviceRes() override; | |||||
| void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, | |||||
| const std::unordered_set<ValueNodePtr> &value_nodes, | |||||
| const std::vector<CNodePtr> &execution_order) override; | |||||
| void AssignMemory(session::KernelGraph *graph) override; | |||||
| bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override; | |||||
| #ifdef ENABLE_DUMP_E2E | |||||
| bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override; | |||||
| #endif | |||||
| protected: | |||||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||||
| TypeId type_id) override; | |||||
| bool SyncStream() override; | |||||
| private: | |||||
| GPUKernelRuntime(const GPUKernelRuntime &); | |||||
| GPUKernelRuntime &operator=(const GPUKernelRuntime &); | |||||
| bool InitDevice(); | |||||
| bool device_init_{false}; | |||||
| // The related functions and members for using dynamic memory pool. | |||||
| void InitKernelRefCount(const session::KernelGraph *graph); | |||||
| void InitKernelOutputAddress(const session::KernelGraph *graph); | |||||
| void InitKernelWorkspaceAddress(const session::KernelGraph *graph); | |||||
| void InitMemorySwapInfo(const session::KernelGraph *graph); | |||||
| void SaveGraphOutputNode(const session::KernelGraph *graph); | |||||
| bool IsGraphOutput(const session::KernelGraph *graph, const mindspore::AnfNodePtr &kernel) const; | |||||
| void ClearKernelOutputAddress(const session::KernelGraph *graph); | |||||
| void ClearKernelWorkspaceAddress(const session::KernelGraph *graph); | |||||
| void ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph); | |||||
| bool RunOneStep(const session::KernelGraph *graph, Debugger *debugger = nullptr); | |||||
| bool SearchMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr); | |||||
| bool RefineMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr); | |||||
| bool LaunchKernelDynamic(const session::KernelGraph *graph, Debugger *debugger = nullptr, bool mock = false, | |||||
| bool profiling = false); | |||||
| void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | |||||
| const AddressPtrList &workspace, const AddressPtrList &outputs); | |||||
| bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock); | |||||
| bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, | |||||
| AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, | |||||
| AddressPtrList *kernel_outputs, bool mock); | |||||
| bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, bool mock); | |||||
| bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, | |||||
| AddressPtrList *kernel_outputs, bool mock); | |||||
| bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, | |||||
| const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces, | |||||
| bool mock); | |||||
| void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); | |||||
| void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); | |||||
| void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); | |||||
| void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, | |||||
| const DeviceAddressPtrList addr_list, size_t total_size, | |||||
| std::vector<size_t> size_list); | |||||
| void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel); | |||||
| bool UpdateMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling); | |||||
| bool AddMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling); | |||||
| void UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock); | |||||
| void UpdateHostSwapOutQueue(bool mock); | |||||
| void ClearSwapInfo(bool mock); | |||||
| std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_; | |||||
| std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_; | |||||
| std::unordered_map<uint32_t, bool> is_first_step_map_; | |||||
| std::unordered_map<uint32_t, std::set<AnfNodePtr>> graph_output_map_; | |||||
| MemReuseUtilPtr mem_reuse_util_{nullptr}; | |||||
| MemSwapManagerPtr mem_swap_manager_{nullptr}; | |||||
| }; | |||||
| MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); | |||||
| } // namespace gpu | |||||
| } // namespace device | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ | |||||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <utility> | |||||
| #include <unordered_map> | |||||
| #include <unordered_set> | |||||
| #include "runtime/device/kernel_runtime.h" | |||||
| #include "runtime/device/kernel_runtime_manager.h" | |||||
| #include "backend/optimizer/mem_reuse/mem_swap_manager.h" | |||||
| namespace mindspore { | |||||
| namespace device { | |||||
| namespace gpu { | |||||
| using mindspore::device::memswap::MemSwapManagerPtr; | |||||
| class GPUKernelRuntime : public KernelRuntime { | |||||
| public: | |||||
| GPUKernelRuntime() = default; | |||||
| ~GPUKernelRuntime() override = default; | |||||
| bool Init() override; | |||||
| void ReleaseDeviceRes() override; | |||||
| void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, | |||||
| const std::unordered_set<ValueNodePtr> &value_nodes, | |||||
| const std::vector<CNodePtr> &execution_order) override; | |||||
| void AssignMemory(session::KernelGraph *graph) override; | |||||
| bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override; | |||||
| #ifdef ENABLE_DUMP_E2E | |||||
| bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override; | |||||
| #endif | |||||
| protected: | |||||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||||
| TypeId type_id) override; | |||||
| bool SyncStream() override; | |||||
| private: | |||||
| GPUKernelRuntime(const GPUKernelRuntime &); | |||||
| GPUKernelRuntime &operator=(const GPUKernelRuntime &); | |||||
| bool InitDevice(); | |||||
| bool device_init_{false}; | |||||
| // The related functions and members for using dynamic memory pool. | |||||
| void InitKernelRefCount(const session::KernelGraph *graph); | |||||
| void InitKernelOutputAddress(const session::KernelGraph *graph); | |||||
| void InitKernelWorkspaceAddress(const session::KernelGraph *graph); | |||||
| void InitMemorySwapInfo(const session::KernelGraph *graph); | |||||
| void SaveGraphOutputNode(const session::KernelGraph *graph); | |||||
| bool IsGraphOutput(const session::KernelGraph *graph, const mindspore::AnfNodePtr &kernel) const; | |||||
| void ClearKernelOutputAddress(const session::KernelGraph *graph); | |||||
| void ClearKernelWorkspaceAddress(const session::KernelGraph *graph); | |||||
| void ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph); | |||||
| bool RunOneStep(const session::KernelGraph *graph, Debugger *debugger = nullptr); | |||||
| bool SearchMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr); | |||||
| bool RefineMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr); | |||||
| bool LaunchKernelDynamic(const session::KernelGraph *graph, Debugger *debugger = nullptr, bool mock = false, | |||||
| bool profiling = false); | |||||
| void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | |||||
| const AddressPtrList &workspace, const AddressPtrList &outputs); | |||||
| bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock); | |||||
| bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, | |||||
| AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, | |||||
| AddressPtrList *kernel_outputs, bool mock); | |||||
| bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, bool mock); | |||||
| bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, | |||||
| AddressPtrList *kernel_outputs, bool mock); | |||||
| bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, | |||||
| const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces, | |||||
| bool mock); | |||||
| void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); | |||||
| void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); | |||||
| void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); | |||||
| void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, | |||||
| const DeviceAddressPtrList addr_list, size_t total_size, | |||||
| std::vector<size_t> size_list); | |||||
| void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel); | |||||
| bool UpdateMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling); | |||||
| bool AddMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling); | |||||
| void UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock); | |||||
| void UpdateHostSwapOutQueue(bool mock); | |||||
| void ClearSwapInfo(bool mock); | |||||
| std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_; | |||||
| std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_; | |||||
| std::unordered_map<uint32_t, bool> is_first_step_map_; | |||||
| std::unordered_map<uint32_t, std::set<AnfNodePtr>> graph_output_map_; | |||||
| MemReuseUtilPtr mem_reuse_util_{nullptr}; | |||||
| MemSwapManagerPtr mem_swap_manager_{nullptr}; | |||||
| }; | |||||
| MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); | |||||
| } // namespace gpu | |||||
| } // namespace device | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ | |||||
| @@ -40,7 +40,7 @@ KernelRuntime::~KernelRuntime() { | |||||
| #endif | #endif | ||||
| } | } | ||||
| bool KernelRuntime::Load(session::KernelGraph *graph) { return true; } | |||||
| bool KernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { return true; } | |||||
| bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) { | bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) { | ||||
| if (graph != nullptr) { | if (graph != nullptr) { | ||||
| @@ -59,8 +59,8 @@ class KernelRuntime { | |||||
| bool DumpDataEnabled(); | bool DumpDataEnabled(); | ||||
| bool DumpDataEnabledIteration(); | bool DumpDataEnabledIteration(); | ||||
| virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr); | virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr); | ||||
| virtual bool Load(session::KernelGraph *graph); | |||||
| virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) = 0; | |||||
| virtual bool Load(session::KernelGraph *graph, bool is_task_sink); | |||||
| virtual bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) = 0; | |||||
| bool LaunchKernel(const session::KernelGraph *graph); | bool LaunchKernel(const session::KernelGraph *graph); | ||||
| bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs, | bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs, | ||||
| const AddressPtrList &kernel_outputs, | const AddressPtrList &kernel_outputs, | ||||