| @@ -26,7 +26,7 @@ void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> call | |||||
| callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end()); | callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end()); | ||||
| } | } | ||||
| Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) { | |||||
| Status CallbackManager::Init(DatasetOp *op) { | |||||
| RETURN_UNEXPECTED_IF_NULL(op); | RETURN_UNEXPECTED_IF_NULL(op); | ||||
| op_ = op; | op_ = op; | ||||
| // turn the flag on if callback is set | // turn the flag on if callback is set | ||||
| @@ -42,6 +42,7 @@ Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) { | |||||
| Status CallbackManager::Begin(const CallbackParam &cb_param) { | Status CallbackManager::Begin(const CallbackParam &cb_param) { | ||||
| RETURN_OK_IF_TRUE(!enabled_); | RETURN_OK_IF_TRUE(!enabled_); | ||||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||||
| std::vector<size_t> callback_inds; | std::vector<size_t> callback_inds; | ||||
| // go through all callback functions to see if each function is needed | // go through all callback functions to see if each function is needed | ||||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | for (size_t ind = 0; ind < callbacks_.size(); ind++) { | ||||
| @@ -61,6 +62,7 @@ Status CallbackManager::Begin(const CallbackParam &cb_param) { | |||||
| Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { | Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { | ||||
| RETURN_OK_IF_TRUE(!enabled_); | RETURN_OK_IF_TRUE(!enabled_); | ||||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||||
| std::vector<size_t> callback_inds; | std::vector<size_t> callback_inds; | ||||
| // go through all callback functions to see if each function is needed | // go through all callback functions to see if each function is needed | ||||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | for (size_t ind = 0; ind < callbacks_.size(); ind++) { | ||||
| @@ -80,6 +82,7 @@ Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { | |||||
| Status CallbackManager::StepBegin(const CallbackParam &cb_param) { | Status CallbackManager::StepBegin(const CallbackParam &cb_param) { | ||||
| RETURN_OK_IF_TRUE(!enabled_); | RETURN_OK_IF_TRUE(!enabled_); | ||||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||||
| std::vector<size_t> callback_inds; | std::vector<size_t> callback_inds; | ||||
| // go through all callback functions to see if each function is needed | // go through all callback functions to see if each function is needed | ||||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | for (size_t ind = 0; ind < callbacks_.size(); ind++) { | ||||
| @@ -100,6 +103,7 @@ Status CallbackManager::StepBegin(const CallbackParam &cb_param) { | |||||
| Status CallbackManager::End(const CallbackParam &cb_param) { | Status CallbackManager::End(const CallbackParam &cb_param) { | ||||
| RETURN_OK_IF_TRUE(!enabled_); | RETURN_OK_IF_TRUE(!enabled_); | ||||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||||
| std::vector<size_t> callback_inds; | std::vector<size_t> callback_inds; | ||||
| // go through all callback functions to see if each function is needed | // go through all callback functions to see if each function is needed | ||||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | for (size_t ind = 0; ind < callbacks_.size(); ind++) { | ||||
| @@ -119,6 +123,7 @@ Status CallbackManager::End(const CallbackParam &cb_param) { | |||||
| Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { | Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { | ||||
| RETURN_OK_IF_TRUE(!enabled_); | RETURN_OK_IF_TRUE(!enabled_); | ||||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||||
| std::vector<size_t> callback_inds; | std::vector<size_t> callback_inds; | ||||
| // go through all callback functions to see if each function is needed | // go through all callback functions to see if each function is needed | ||||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | for (size_t ind = 0; ind < callbacks_.size(); ind++) { | ||||
| @@ -138,6 +143,7 @@ Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { | |||||
| Status CallbackManager::StepEnd(const CallbackParam &cb_param) { | Status CallbackManager::StepEnd(const CallbackParam &cb_param) { | ||||
| RETURN_OK_IF_TRUE(!enabled_); | RETURN_OK_IF_TRUE(!enabled_); | ||||
| RETURN_UNEXPECTED_IF_NULL(op_); | |||||
| std::vector<size_t> callback_inds; | std::vector<size_t> callback_inds; | ||||
| // go through all callback functions to see if each function is needed | // go through all callback functions to see if each function is needed | ||||
| for (size_t ind = 0; ind < callbacks_.size(); ind++) { | for (size_t ind = 0; ind < callbacks_.size(); ind++) { | ||||
| @@ -44,7 +44,7 @@ class CallbackManager { | |||||
| /// \brief DatasetOp needs to call Init if it wishes to use callback, Init will set enabled_ to true | /// \brief DatasetOp needs to call Init if it wishes to use callback, Init will set enabled_ to true | ||||
| /// \param[in] op, this pointer is used for Callback Manager to Pause Worker threads | /// \param[in] op, this pointer is used for Callback Manager to Pause Worker threads | ||||
| /// \return Status | /// \return Status | ||||
| Status Init(std::shared_ptr<DatasetOp> op); | |||||
| Status Init(DatasetOp *op); | |||||
| /// \brief callback function called at the start of the first row | /// \brief callback function called at the start of the first row | ||||
| /// \return Status | /// \return Status | ||||
| @@ -70,11 +70,9 @@ class CallbackManager { | |||||
| /// \return Status | /// \return Status | ||||
| Status StepEnd(const CallbackParam &); | Status StepEnd(const CallbackParam &); | ||||
| bool HasCallback() { return !callbacks_.empty(); } | |||||
| private: | private: | ||||
| bool enabled_; // flag to enable callback, if false, all functions would return immediately | |||||
| std::shared_ptr<DatasetOp> op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager | |||||
| bool enabled_; // flag to enable callback, if false, all functions would return immediately | |||||
| DatasetOp *op_; // back pointer to DatasetOp, raw pointer to avoid circular ownership | |||||
| std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call | std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -164,9 +164,7 @@ Status MapOp::operator()() { | |||||
| // Create and register the local queues. | // Create and register the local queues. | ||||
| local_queues_.Init(num_workers_, oc_queue_size_); | local_queues_.Init(num_workers_, oc_queue_size_); | ||||
| // init callback | // init callback | ||||
| if (callback_manager_.HasCallback()) { | |||||
| RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this())); | |||||
| } | |||||
| RETURN_IF_NOT_OK(callback_manager_.Init(this)); | |||||
| Status rc = local_queues_.Register(tree_->AllTasks()); | Status rc = local_queues_.Register(tree_->AllTasks()); | ||||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| @@ -181,26 +179,23 @@ Status MapOp::operator()() { | |||||
| RETURN_IF_NOT_OK(rc); | RETURN_IF_NOT_OK(rc); | ||||
| // num_buffers received, including eoe, num_epoch, num_step of current epoch | // num_buffers received, including eoe, num_epoch, num_step of current epoch | ||||
| int64_t num_buf = 0, ep_step = 0, total_step = 0; | int64_t num_buf = 0, ep_step = 0, total_step = 0; | ||||
| if (callback_manager_.HasCallback()) { | |||||
| RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); | |||||
| } | |||||
| RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); | |||||
| std::unique_ptr<DataBuffer> buff; | std::unique_ptr<DataBuffer> buff; | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); | RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); | ||||
| while (!buff->eof()) { | while (!buff->eof()) { | ||||
| if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) { | if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) { | ||||
| if (callback_manager_.HasCallback()) { | |||||
| RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||||
| } | |||||
| RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||||
| } | } | ||||
| while (!buff->eoe()) { | while (!buff->eoe()) { | ||||
| ep_step++; | ep_step++; | ||||
| total_step++; | total_step++; | ||||
| // Create an empty map worker job to be populated by a databuffer and map jobs | // Create an empty map worker job to be populated by a databuffer and map jobs | ||||
| if (callback_manager_.HasCallback()) { | |||||
| RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||||
| } | |||||
| RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||||
| std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff)); | std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff)); | ||||
| // Populate map worker job for a worker to execute | // Populate map worker job for a worker to execute | ||||
| @@ -208,18 +203,16 @@ Status MapOp::operator()() { | |||||
| // Push map worker job to the corresponding worker's queue | // Push map worker job to the corresponding worker's queue | ||||
| RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); | RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); | ||||
| if (callback_manager_.HasCallback()) { | |||||
| RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||||
| } | |||||
| RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); | RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); | ||||
| } | } | ||||
| // check whether this is the end of a real epoch (not all eoe signals end of epoch) | // check whether this is the end of a real epoch (not all eoe signals end of epoch) | ||||
| if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) { | if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) { | ||||
| if (callback_manager_.HasCallback()) { | |||||
| RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||||
| } | |||||
| RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||||
| ep_step = 0; | ep_step = 0; | ||||
| } | } | ||||
| // Propagate the eoe buffer to worker | // Propagate the eoe buffer to worker | ||||