From 21f25cbbb6378998dbb2cd143fd91428bce748ab Mon Sep 17 00:00:00 2001 From: anzhengqi Date: Sat, 12 Sep 2020 17:28:03 +0800 Subject: [PATCH] avoid memory ascend for multi model.train or model.eval --- .../dataset/callback/callback_manager.h | 2 ++ .../engine/datasetops/map_op/map_op.cc | 24 ++++++++++++++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h index 0392a0d2cf..c0e65126a0 100644 --- a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h +++ b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h @@ -70,6 +70,8 @@ class CallbackManager { /// \return Status Status StepEnd(const CallbackParam &); + bool HasCallback() { return !callbacks_.empty(); } + private: bool enabled_; // flag to enable callback, if false, all functions would return immediately std::shared_ptr op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc index 2556cc120b..eff69d7609 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc @@ -164,7 +164,9 @@ Status MapOp::operator()() { // Create and register the local queues. local_queues_.Init(num_workers_, oc_queue_size_); // init callback - RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this())); + if (callback_manager_.HasCallback()) { + RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this())); + } Status rc = local_queues_.Register(tree_->AllTasks()); RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); if (rc.IsError()) { @@ -179,20 +181,26 @@ Status MapOp::operator()() { RETURN_IF_NOT_OK(rc); // num_buffers received, including eoe, num_epoch, num_step of current epoch int64_t num_buf = 0, ep_step = 0, total_step = 0; - RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); + if (callback_manager_.HasCallback()) { + RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); + } std::unique_ptr buff; RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); while (!buff->eof()) { if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) { - RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + if (callback_manager_.HasCallback()) { + RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + } } while (!buff->eoe()) { ep_step++; total_step++; // Create an empty map worker job to be populated by a databuffer and map jobs - RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + if (callback_manager_.HasCallback()) { + RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + } std::unique_ptr worker_job = std::make_unique(std::move(buff)); // Populate map worker job for a worker to execute @@ -200,14 +208,18 @@ Status MapOp::operator()() { // Push map worker job to the corresponding worker's queue RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); - RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + if (callback_manager_.HasCallback()) { + RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + } RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); } // check whether this is the end of a real epoch (not all eoe signals end of epoch) if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) { - RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + if (callback_manager_.HasCallback()) { + RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + } ep_step = 0; } // Propagate the eoe buffer to worker