| @@ -293,10 +293,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(kernel_graph); | |||
| auto &input_nodes = kernel_graph->input_nodes(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | |||
| if (enable_mem_scheduler) { | |||
| if (device::KernelRuntime::use_mem_scheduler()) { | |||
| kernel_graph->SetInputTensors(inputs); | |||
| return; | |||
| } | |||
| @@ -342,6 +339,8 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra | |||
| tensor->data_c(), tensor->device_info().host_format_)) { | |||
| MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode || | |||
| AnfAlgo::IsParameterWeight(input_param) || kernel_graph->IsUpdatedParameter(input_param)) { | |||
| tensor->set_device_address(device_address); | |||
| @@ -539,8 +538,7 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) { | |||
| } else { | |||
| // alloc memory, including static memory and dynamic memory | |||
| MemoryAlloc(graph.get()); | |||
| auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | |||
| if (!enable_mem_scheduler) { | |||
| if (!device::KernelRuntime::use_mem_scheduler()) { | |||
| AnfAlgo::CacheAddrForGraph(graph); | |||
| } | |||
| // generate and load task info to device if it is sink mode | |||
| @@ -577,8 +575,7 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { | |||
| // optimize graph | |||
| HardwareOptimize(child_graph); | |||
| // assign static memory of parameters | |||
| auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | |||
| if (!enable_mem_scheduler) { | |||
| if (!device::KernelRuntime::use_mem_scheduler()) { | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| runtime_instance->AssignStaticMemoryInput(*child_graph); | |||
| @@ -1801,10 +1798,7 @@ void AscendSession::ExecuteAllTaskInQueue() { | |||
| void AscendSession::UpdateOutputTensors(const VectorRef *outputs, | |||
| const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node, | |||
| std::map<DeviceAddressPtr, DeviceAddressPtr> *) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | |||
| if (enable_mem_scheduler) { | |||
| if (device::KernelRuntime::use_mem_scheduler()) { | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| @@ -160,7 +160,8 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||
| void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) { | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node, | |||
| KernelMapTensor *) { | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs, tensor_to_node); | |||
| @@ -34,7 +34,8 @@ class CPUSession : public SessionBasic { | |||
| protected: | |||
| void UnifyMindIR(const KernelGraphPtr &graph) override { SessionBasic::UnifyMindIR(graph); } | |||
| void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) override; | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node, | |||
| KernelMapTensor *node_to_tensor) override; | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *const outputs) override; | |||
| @@ -130,6 +130,9 @@ void RunGraphTask::Run() { | |||
| return; | |||
| } | |||
| graph->ResetGraphRunningStatus(); | |||
| if (device::KernelRuntime::use_mem_scheduler()) { | |||
| graph->SetOutputNodeToTensor(node_to_tensor_); | |||
| } | |||
| try { | |||
| session_->LoadInputs(graph_id_, input_tensors_); | |||
| session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); | |||
| @@ -361,7 +364,7 @@ void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id, | |||
| task->session_ = session; | |||
| task->graph_id_ = graph_id; | |||
| task->input_tensors_ = inputs; | |||
| session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_); | |||
| session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_); | |||
| task->outputs_ = *outputs; | |||
| task->sync_run_ = true; | |||
| RunTask(task, true, true); | |||
| @@ -383,7 +386,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, | |||
| reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); }); | |||
| MsException::Instance().CheckException(); | |||
| } | |||
| session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_); | |||
| session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_); | |||
| // maintain a copy of output vector | |||
| task->outputs_ = *outputs; | |||
| @@ -97,6 +97,7 @@ class RunGraphTask : public Task { | |||
| VectorRef outputs_; | |||
| GraphId graph_id_{0}; | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_; | |||
| KernelMapTensor node_to_tensor_; | |||
| }; | |||
| class RunOpsInGraphTask : public Task { | |||
| @@ -1679,24 +1679,20 @@ std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const Graph | |||
| void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) { | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node, | |||
| KernelMapTensor *node_to_tensor) { | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| MS_EXCEPTION_IF_NULL(tensor_to_node); | |||
| auto anf_outputs = kernel_graph->outputs(); | |||
| KernelMapTensor node_to_tensor; | |||
| for (auto &item : anf_outputs) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]"; | |||
| outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor)); | |||
| outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, node_to_tensor)); | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | |||
| if (enable_mem_scheduler) { | |||
| kernel_graph->SetOutputNodeToTensor(node_to_tensor); | |||
| } | |||
| } | |||
| void SessionBasic::UpdateOutputTensors(const VectorRef *outputs, | |||
| @@ -1704,8 +1700,7 @@ void SessionBasic::UpdateOutputTensors(const VectorRef *outputs, | |||
| std::map<DeviceAddressPtr, DeviceAddressPtr> *) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | |||
| if (enable_mem_scheduler) { | |||
| if (device::KernelRuntime::use_mem_scheduler()) { | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| @@ -211,7 +211,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| virtual bool IsSupportSummary() { return true; } | |||
| virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node); | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node, | |||
| KernelMapTensor *node_to_tensor); | |||
| // When a node's device address is used as an output of the graph, that device address is passed | | |
| // to the output tensor, and the output node creates a new device address. The third parameter records | | |
| // the relationship between the new and old device addresses. | | |
| @@ -98,8 +98,7 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_ | |||
| void KernelRuntime::AssignMemory(const session::KernelGraph &graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | |||
| if (enable_mem_scheduler) { | |||
| if (use_mem_scheduler()) { | |||
| AssignStaticMemoryValueNode(graph); | |||
| ResetNodeAddress(graph); | |||
| } else { | |||
| @@ -1175,6 +1174,17 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod | |||
| } | |||
| } | |||
| // Tells callers whether the memory scheduler should be used for the current run. | | |
| // Returns true only when MS_CTX_ENABLE_MEM_SCHEDULER is set, MS_CTX_ENABLE_PYNATIVE_INFER is not set, | | |
| // and MS_CTX_EXECUTION_MODE is not kPynativeMode. The global MsContext instance is null-checked | | |
| // with MS_EXCEPTION_IF_NULL before it is queried. | | |
| bool KernelRuntime::use_mem_scheduler() { | | |
| auto ms_ctx = MsContext::GetInstance(); | | |
| MS_EXCEPTION_IF_NULL(ms_ctx); | | |
| const bool scheduler_enabled = ms_ctx->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | | |
| if (!scheduler_enabled) { | | |
| return false; | | |
| } | | |
| // The MemScheduler is not used when running a single op: rule out PyNative inference first, | | |
| // then PyNative execution mode. | | |
| if (ms_ctx->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) { | | |
| return false; | | |
| } | | |
| return ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode; | | |
| } | | |
| void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs, | |||
| const std::shared_ptr<MemScheduler> &mem_scheduler) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -1347,28 +1357,29 @@ void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &m | |||
| } | |||
| continue; | |||
| } | |||
| if (tensor != nullptr) { | |||
| if (device_address == nullptr) { | |||
| tensor->data_sync(false); | |||
| tensor->set_device_address(nullptr); | |||
| tensor->set_sync_status(kNeedSyncHostToDevice); | |||
| continue; | |||
| } | |||
| if (!SyncStream()) { | |||
| MS_LOG(ERROR) << "SyncStream failed"; | |||
| } | |||
| auto origin_ptr = device_address->ptr_; | |||
| if (origin_ptr == nullptr) { | |||
| device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_); | |||
| } | |||
| tensor->set_device_address(device_address); | |||
| if (tensor == nullptr) { | |||
| continue; | |||
| } | |||
| if (device_address == nullptr) { | |||
| tensor->data_sync(false); | |||
| tensor->set_device_address(nullptr); | |||
| if (origin_ptr == nullptr) { | |||
| device_address->ptr_ = nullptr; | |||
| } | |||
| tensor->set_sync_status(kNeedSyncHostToDevice); | |||
| continue; | |||
| } | |||
| if (!SyncStream()) { | |||
| MS_LOG(EXCEPTION) << "SyncStream failed"; | |||
| } | |||
| auto origin_ptr = device_address->ptr_; | |||
| if (origin_ptr == nullptr) { | |||
| device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_); | |||
| } | |||
| tensor->set_device_address(device_address); | |||
| tensor->data_sync(false); | |||
| tensor->set_device_address(nullptr); | |||
| if (origin_ptr == nullptr) { | |||
| device_address->ptr_ = nullptr; | |||
| } | |||
| tensor->set_sync_status(kNeedSyncHostToDevice); | |||
| } | |||
| } | |||
| @@ -1384,21 +1395,24 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m | |||
| auto tensor = input_tensors[i]; | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto input_node = input_nodes[i]; | |||
| if (!input_node->isa<Parameter>()) { | |||
| if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::OutputAddrExist(input_node, 0)) { | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| MemPriority priority = kMemPriorityHigh; | |||
| auto tensor_address = tensor->device_address(); | |||
| if (tensor_address != nullptr && tensor_address != device_address) { | |||
| tensor->data_sync(false); | |||
| priority = kMemPriorityLow; | |||
| } | |||
| auto tensor_size = LongToSize(tensor->data().nbytes()); | |||
| mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority); | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| MemPriority priority = kMemPriorityLow; | |||
| auto tensor_address = tensor->device_address(); | |||
| if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) { | |||
| tensor->data_sync(false); | |||
| } | |||
| if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) || | |||
| graph.IsUpdatedParameter(input_node->cast<ParameterPtr>())) { | |||
| tensor->set_device_address(device_address); | |||
| priority = kMemPriorityHigh; | |||
| } | |||
| auto tensor_size = LongToSize(tensor->data().nbytes()); | |||
| mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority); | |||
| tensor->set_sync_status(kNoNeedSync); | |||
| } | |||
| } | |||
| @@ -1451,8 +1465,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::shared_ptr<MemScheduler> mem_scheduler = nullptr; | |||
| auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | |||
| if (enable_mem_scheduler) { | |||
| if (use_mem_scheduler()) { | |||
| mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id()); | |||
| MS_EXCEPTION_IF_NULL(mem_scheduler); | |||
| mem_scheduler->SetMemHandler(mem_manager_); | |||
| @@ -1520,28 +1533,28 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock | |||
| void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER); | |||
| if (enable_mem_scheduler) { | |||
| auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id()); | |||
| if (mem_scheduler->need_record_event()) { | |||
| (void)LaunchKernelMod(graph, true); | |||
| mem_scheduler->set_need_record_event(false); | |||
| } | |||
| float mem_used_factor = kMaxMemReuseFactor; | |||
| while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) { | |||
| mem_scheduler->SetMemUsedFactor(mem_used_factor); | |||
| mem_scheduler->OptMemUsage(); | |||
| bool ret = LaunchKernelMod(graph, true); | |||
| if (ret) { | |||
| mem_scheduler->set_optimized(true); | |||
| } else { | |||
| mem_used_factor -= kRetryFactor; | |||
| } | |||
| } | |||
| if (!mem_scheduler->optimized()) { | |||
| MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit."; | |||
| if (!use_mem_scheduler()) { | |||
| return; | |||
| } | |||
| auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id()); | |||
| if (mem_scheduler->need_record_event()) { | |||
| (void)LaunchKernelMod(graph, true); | |||
| mem_scheduler->set_need_record_event(false); | |||
| } | |||
| float mem_used_factor = kMaxMemReuseFactor; | |||
| while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) { | |||
| mem_scheduler->SetMemUsedFactor(mem_used_factor); | |||
| mem_scheduler->OptMemUsage(); | |||
| bool ret = LaunchKernelMod(graph, true); | |||
| if (ret) { | |||
| mem_scheduler->set_optimized(true); | |||
| } else { | |||
| mem_used_factor -= kRetryFactor; | |||
| } | |||
| } | |||
| if (!mem_scheduler->optimized()) { | |||
| MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit."; | |||
| } | |||
| } | |||
| bool KernelRuntime::LaunchKernels(const session::KernelGraph &graph) { | |||
| @@ -94,6 +94,7 @@ class KernelRuntime { | |||
| virtual void ReleaseDeviceRes() {} | |||
| void set_device_id(uint32_t device_id) { device_id_ = device_id; } | |||
| uint32_t device_id() { return device_id_; } | |||
| static bool use_mem_scheduler(); | |||
| #ifdef ENABLE_DEBUGGER | |||
| // set debugger | |||
| @@ -202,6 +202,9 @@ void MemScheduler::CountMemUsage() { | |||
| if (!min_mem_used_.empty()) { | |||
| return; | |||
| } | |||
| if (mem_events_.empty() || compute_index_ == 0) { | |||
| return; | |||
| } | |||
| min_mem_used_.resize(compute_index_, 0); | |||
| std::vector<size_t> total_mem_used(compute_index_, 0); | |||
| for (auto &item : mem_events_) { | |||