| @@ -291,7 +291,7 @@ void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) { | |||||
| // only send compiled graphs once. | // only send compiled graphs once. | ||||
| SendMultiGraphsAndSuspend(graph_proto_list_, graph_sum); | SendMultiGraphsAndSuspend(graph_proto_list_, graph_sum); | ||||
| graph_proto_list_.clear(); | graph_proto_list_.clear(); | ||||
| } else if (graph_id == rungraph_id_list_.front()) { | |||||
| } else if (graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) { | |||||
| // stop only when receive the first sub run graph for each step | // stop only when receive the first sub run graph for each step | ||||
| CommandLoop(); | CommandLoop(); | ||||
| } | } | ||||
| @@ -394,6 +394,7 @@ void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) { | |||||
| auto graph_proto = GetGraphProto(graph_ptr); | auto graph_proto = GetGraphProto(graph_ptr); | ||||
| // add new graph proto to graph_proto_list_ | // add new graph proto to graph_proto_list_ | ||||
| graph_proto_list_.push_back(graph_proto); | graph_proto_list_.push_back(graph_proto); | ||||
| graph_ptr_list_.push_back(graph_ptr); | |||||
| not_dataset_graph_sum_++; | not_dataset_graph_sum_++; | ||||
| } | } | ||||
| // reset is_dataset_graph to be false | // reset is_dataset_graph to be false | ||||
| @@ -118,6 +118,10 @@ class Debugger : public std::enable_shared_from_this<Debugger> { | |||||
| uint32_t GetFirstRunGraphId(); | uint32_t GetFirstRunGraphId(); | ||||
| void SetGraphPtr(const KernelGraphPtr &graph_ptr) { graph_ptr_ = graph_ptr; } | |||||
| std::list<KernelGraphPtr> GetGraphPtrList() { return graph_ptr_list_; } | |||||
| private: | private: | ||||
| // private constructor for singleton | // private constructor for singleton | ||||
| Debugger(); | Debugger(); | ||||
| @@ -204,6 +208,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> { | |||||
| // flag to keep track of the very first suspension of debugger | // flag to keep track of the very first suspension of debugger | ||||
| bool initial_suspend_; | bool initial_suspend_; | ||||
| std::list<GraphProto> graph_proto_list_; | std::list<GraphProto> graph_proto_list_; | ||||
| std::list<KernelGraphPtr> graph_ptr_list_; | |||||
| // singleton | // singleton | ||||
| static std::mutex instance_lock_; | static std::mutex instance_lock_; | ||||
| @@ -287,10 +287,13 @@ bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph) { | |||||
| MS_LOG(INFO) << "Start load step"; | MS_LOG(INFO) << "Start load step"; | ||||
| uint32_t cur_iter = 0; | uint32_t cur_iter = 0; | ||||
| MS_LOG(INFO) << "Cur iter is " << cur_iter; | MS_LOG(INFO) << "Cur iter is " << cur_iter; | ||||
| // load output | |||||
| debugger_->LoadGraphOutputs(); | |||||
| // load parameters | |||||
| debugger_->LoadParametersAndConst(); | |||||
| for (auto graph_ptr : debugger_->GetGraphPtrList()) { | |||||
| debugger_->SetGraphPtr(graph_ptr); | |||||
| // load output | |||||
| debugger_->LoadGraphOutputs(); | |||||
| // load parameters | |||||
| debugger_->LoadParametersAndConst(); | |||||
| } | |||||
| #endif | #endif | ||||
| return true; | return true; | ||||
| } | } | ||||