Browse Source

!6121 avoid momory ascend for multi model.train or model.eval

Merge pull request !6121 from anzhengqi/avoid-memory-ascend-lenet
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
84240b5a5c
2 changed files with 20 additions and 6 deletions
  1. +2
    -0
      mindspore/ccsrc/minddata/dataset/callback/callback_manager.h
  2. +18
    -6
      mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc

+ 2
- 0
mindspore/ccsrc/minddata/dataset/callback/callback_manager.h View File

@@ -70,6 +70,8 @@ class CallbackManager {
/// \return Status
Status StepEnd(const CallbackParam &);

bool HasCallback() { return !callbacks_.empty(); }

private:
bool enabled_; // flag to enable callback, if false, all functions would return immediately
std::shared_ptr<DatasetOp> op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager


+ 18
- 6
mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc View File

@@ -164,7 +164,9 @@ Status MapOp::operator()() {
// Create and register the local queues.
local_queues_.Init(num_workers_, oc_queue_size_);
// init callback
RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
if (callback_manager_.HasCallback()) {
RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
}
Status rc = local_queues_.Register(tree_->AllTasks());
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
if (rc.IsError()) {
@@ -179,20 +181,26 @@ Status MapOp::operator()() {
RETURN_IF_NOT_OK(rc);
// num_buffers received, including eoe, num_epoch, num_step of current epoch
int64_t num_buf = 0, ep_step = 0, total_step = 0;
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
if (callback_manager_.HasCallback()) {
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
}

std::unique_ptr<DataBuffer> buff;

RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
while (!buff->eof()) {
if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
if (callback_manager_.HasCallback()) {
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}
}
while (!buff->eoe()) {
ep_step++;
total_step++;
// Create an empty map worker job to be populated by a databuffer and map jobs
RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
if (callback_manager_.HasCallback()) {
RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));

// Populate map worker job for a worker to execute
@@ -200,14 +208,18 @@ Status MapOp::operator()() {

// Push map worker job to the corresponding worker's queue
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
if (callback_manager_.HasCallback()) {
RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}

RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
}

// check whether this is the end of a real epoch (not all eoe signals end of epoch)
if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
if (callback_manager_.HasCallback()) {
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}
ep_step = 0;
}
// Propagate the eoe buffer to worker


Loading…
Cancel
Save