| @@ -26,7 +26,7 @@ void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> call | |||
| callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end()); | |||
| } | |||
| Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) { | |||
| Status CallbackManager::Init(DatasetOp *op) { | |||
| RETURN_UNEXPECTED_IF_NULL(op); | |||
| op_ = op; | |||
| // turn the flag on if callback is set | |||
| @@ -42,6 +42,7 @@ Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) { | |||
| Status CallbackManager::Begin(const CallbackParam &cb_param) { | |||
| RETURN_OK_IF_TRUE(!enabled_); | |||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||
| std::vector<size_t> callback_inds; | |||
| // go through all callback functions to see if each function is needed | |||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | |||
| @@ -61,6 +62,7 @@ Status CallbackManager::Begin(const CallbackParam &cb_param) { | |||
| Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { | |||
| RETURN_OK_IF_TRUE(!enabled_); | |||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||
| std::vector<size_t> callback_inds; | |||
| // go through all callback functions to see if each function is needed | |||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | |||
| @@ -80,6 +82,7 @@ Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { | |||
| Status CallbackManager::StepBegin(const CallbackParam &cb_param) { | |||
| RETURN_OK_IF_TRUE(!enabled_); | |||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||
| std::vector<size_t> callback_inds; | |||
| // go through all callback functions to see if each function is needed | |||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | |||
| @@ -100,6 +103,7 @@ Status CallbackManager::StepBegin(const CallbackParam &cb_param) { | |||
| Status CallbackManager::End(const CallbackParam &cb_param) { | |||
| RETURN_OK_IF_TRUE(!enabled_); | |||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||
| std::vector<size_t> callback_inds; | |||
| // go through all callback functions to see if each function is needed | |||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | |||
| @@ -119,6 +123,7 @@ Status CallbackManager::End(const CallbackParam &cb_param) { | |||
| Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { | |||
| RETURN_OK_IF_TRUE(!enabled_); | |||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||
| std::vector<size_t> callback_inds; | |||
| // go through all callback functions to see if each function is needed | |||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | |||
| @@ -138,6 +143,7 @@ Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { | |||
| Status CallbackManager::StepEnd(const CallbackParam &cb_param) { | |||
| RETURN_OK_IF_TRUE(!enabled_); | |||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||
| std::vector<size_t> callback_inds; | |||
| // go through all callback functions to see if each function is needed | |||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | |||
| @@ -44,7 +44,7 @@ class CallbackManager { | |||
| /// \brief DatasetOp needs to call Init if it wishes to use callback, Init will set enabled_ to true | |||
| /// \param[in] op, this pointer is used for Callback Manager to Pause Worker threads | |||
| /// \return Status | |||
| Status Init(std::shared_ptr<DatasetOp> op); | |||
| Status Init(DatasetOp *op); | |||
| /// \brief callback function called at the start of the first row | |||
| /// \return Status | |||
| @@ -70,11 +70,9 @@ class CallbackManager { | |||
| /// \return Status | |||
| Status StepEnd(const CallbackParam &); | |||
| bool HasCallback() { return !callbacks_.empty(); } | |||
| private: | |||
| bool enabled_; // flag to enable callback, if false, all functions would return immediately | |||
| std::shared_ptr<DatasetOp> op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager | |||
| bool enabled_; // flag to enable callback, if false, all functions would return immediately | |||
| DatasetOp *op_; // back pointer to DatasetOp, raw pointer to avoid circular ownership | |||
| std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call | |||
| }; | |||
| } // namespace dataset | |||
| @@ -164,9 +164,7 @@ Status MapOp::operator()() { | |||
| // Create and register the local queues. | |||
| local_queues_.Init(num_workers_, oc_queue_size_); | |||
| // init callback | |||
| if (callback_manager_.HasCallback()) { | |||
| RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this())); | |||
| } | |||
| RETURN_IF_NOT_OK(callback_manager_.Init(this)); | |||
| Status rc = local_queues_.Register(tree_->AllTasks()); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| if (rc.IsError()) { | |||
| @@ -181,26 +179,23 @@ Status MapOp::operator()() { | |||
| RETURN_IF_NOT_OK(rc); | |||
| // num_buffers received, including eoe, num_epoch, num_step of current epoch | |||
| int64_t num_buf = 0, ep_step = 0, total_step = 0; | |||
| if (callback_manager_.HasCallback()) { | |||
| RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); | |||
| } | |||
| RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); | |||
| std::unique_ptr<DataBuffer> buff; | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); | |||
| while (!buff->eof()) { | |||
| if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) { | |||
| if (callback_manager_.HasCallback()) { | |||
| RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||
| } | |||
| RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||
| } | |||
| while (!buff->eoe()) { | |||
| ep_step++; | |||
| total_step++; | |||
| // Create an empty map worker job to be populated by a databuffer and map jobs | |||
| if (callback_manager_.HasCallback()) { | |||
| RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||
| } | |||
| RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||
| std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff)); | |||
| // Populate map worker job for a worker to execute | |||
| @@ -208,18 +203,16 @@ Status MapOp::operator()() { | |||
| // Push map worker job to the corresponding worker's queue | |||
| RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); | |||
| if (callback_manager_.HasCallback()) { | |||
| RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||
| } | |||
| RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); | |||
| } | |||
| // check whether this is the end of a real epoch (not all eoe signals end of epoch) | |||
| if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) { | |||
| if (callback_manager_.HasCallback()) { | |||
| RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||
| } | |||
| RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||
| ep_step = 0; | |||
| } | |||
| // Propagate the eoe buffer to worker | |||