Browse Source

load tensors for all graphs

tags/v1.1.0
yelihua 5 years ago
parent
commit
5026ef1d46
3 changed files with 14 additions and 5 deletions
  1. +2
    -1
      mindspore/ccsrc/debug/debugger/debugger.cc
  2. +5
    -0
      mindspore/ccsrc/debug/debugger/debugger.h
  3. +7
    -4
      mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc

+ 2
- 1
mindspore/ccsrc/debug/debugger/debugger.cc View File

@@ -291,7 +291,7 @@ void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) {
// only send compiled graphs once.
SendMultiGraphsAndSuspend(graph_proto_list_, graph_sum);
graph_proto_list_.clear();
} else if (graph_id == rungraph_id_list_.front()) {
} else if (graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
// stop only when receive the first sub run graph for each step
CommandLoop();
}
@@ -394,6 +394,7 @@ void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
auto graph_proto = GetGraphProto(graph_ptr);
// add new graph proto to graph_proto_list_
graph_proto_list_.push_back(graph_proto);
graph_ptr_list_.push_back(graph_ptr);
not_dataset_graph_sum_++;
}
// reset is_dataset_graph to be false


+ 5
- 0
mindspore/ccsrc/debug/debugger/debugger.h View File

@@ -118,6 +118,10 @@ class Debugger : public std::enable_shared_from_this<Debugger> {

uint32_t GetFirstRunGraphId();

void SetGraphPtr(const KernelGraphPtr &graph_ptr) { graph_ptr_ = graph_ptr; }

std::list<KernelGraphPtr> GetGraphPtrList() { return graph_ptr_list_; }

private:
// private constructor for singleton
Debugger();
@@ -204,6 +208,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
// flag to keep track of the very first suspension of debugger
bool initial_suspend_;
std::list<GraphProto> graph_proto_list_;
std::list<KernelGraphPtr> graph_ptr_list_;

// singleton
static std::mutex instance_lock_;


+ 7
- 4
mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc View File

@@ -287,10 +287,13 @@ bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph) {
MS_LOG(INFO) << "Start load step";
uint32_t cur_iter = 0;
MS_LOG(INFO) << "Cur iter is " << cur_iter;
// load output
debugger_->LoadGraphOutputs();
// load parameters
debugger_->LoadParametersAndConst();
for (auto graph_ptr : debugger_->GetGraphPtrList()) {
debugger_->SetGraphPtr(graph_ptr);
// load output
debugger_->LoadGraphOutputs();
// load parameters
debugger_->LoadParametersAndConst();
}
#endif
return true;
}


Loading…
Cancel
Save