| @@ -51,7 +51,15 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| PipelineOp::Print(out, show_all); | PipelineOp::Print(out, show_all); | ||||
| // Then show any custom derived-internal stuff | // Then show any custom derived-internal stuff | ||||
| out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_; | |||||
| out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_ | |||||
| << "\nLeaf Nodes in execution path:"; | |||||
| if (!eoe_ops_.empty()) { | |||||
| for (size_t i = 0; i < eoe_ops_.size(); i++) { | |||||
| out << "\n Operator: " << eoe_ops_[i]->id(); | |||||
| } | |||||
| } else { | |||||
| out << " None."; | |||||
| } | |||||
| out << "\n\n"; | out << "\n\n"; | ||||
| } | } | ||||
| } | } | ||||
| @@ -86,6 +94,13 @@ Status EpochCtrlOp::EoeReceived(int32_t worker_id) { | |||||
| // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. | // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. | ||||
| state_ = OpState::kDeOpIdle; | state_ = OpState::kDeOpIdle; | ||||
| if (repeat_count_ != num_repeats_) { | |||||
| for (auto &eoe_op : eoe_ops_) { | |||||
| MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id(); | |||||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -62,7 +62,15 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| PipelineOp::Print(out, show_all); | PipelineOp::Print(out, show_all); | ||||
| // Then show any custom derived-internal stuff | // Then show any custom derived-internal stuff | ||||
| out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_; | |||||
| out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_ | |||||
| << "\nLeaf Nodes in execution path:"; | |||||
| if (!eoe_ops_.empty()) { | |||||
| for (size_t i = 0; i < eoe_ops_.size(); i++) { | |||||
| out << "\n Operator: " << eoe_ops_[i]->id(); | |||||
| } | |||||
| } else { | |||||
| out << " None."; | |||||
| } | |||||
| out << "\n\n"; | out << "\n\n"; | ||||
| } | } | ||||
| } | } | ||||
| @@ -107,9 +115,17 @@ Status RepeatOp::EoeReceived(int32_t worker_id) { | |||||
| if (repeat_count_ == num_repeats_) { | if (repeat_count_ == num_repeats_) { | ||||
| repeat_count_ = 0; | repeat_count_ = 0; | ||||
| state_ = OpState::kDeOpIdle; | state_ = OpState::kDeOpIdle; | ||||
| return Status::OK(); | |||||
| } else { | } else { | ||||
| state_ = OpState::kDeOpRunning; | state_ = OpState::kDeOpRunning; | ||||
| } | } | ||||
| // Invoke a reset against the eoe nodes only. | |||||
| for (auto &eoe_op : eoe_ops_) { | |||||
| MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id(); | |||||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -138,6 +154,19 @@ int32_t RepeatOp::num_consumers() const { | |||||
| } | } | ||||
| } | } | ||||
| // Drive reset actions if needed | |||||
| Status RepeatOp::Reset() { | |||||
| // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. | |||||
| // In that case, we now have to bounce the reset down to our own eoe ops. | |||||
| MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset."; | |||||
| for (auto &eoe_op : eoe_ops_) { | |||||
| MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id(); | |||||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||||
| } | |||||
| state_ = OpState::kDeOpRunning; | |||||
| return Status::OK(); | |||||
| } | |||||
| int32_t RepeatOp::num_producers() const { | int32_t RepeatOp::num_producers() const { | ||||
| if (child_.empty() || child_[0] == nullptr) { | if (child_.empty() || child_[0] == nullptr) { | ||||
| MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; | MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; | ||||
| @@ -129,6 +129,16 @@ class RepeatOp : public PipelineOp { | |||||
| /// \return The number of repeats that the user requested | /// \return The number of repeats that the user requested | ||||
| int32_t num_repeats() { return num_repeats_; } | int32_t num_repeats() { return num_repeats_; } | ||||
| /// \brief reset Op | |||||
| /// \@return Status - The error code return | |||||
| Status Reset() override; | |||||
| // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes | |||||
| // \param[in] eoe_op The input leaf/eoe operator to add to the list | |||||
| void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } | |||||
| std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat. | |||||
| protected: | protected: | ||||
| // The number of repeats that the user requested. | // The number of repeats that the user requested. | ||||
| // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. | // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. | ||||
| @@ -186,6 +186,7 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) { | |||||
| Status GeneratorOp::operator()() { | Status GeneratorOp::operator()() { | ||||
| // Handshake with TaskManager to synchronize thread creation | // Handshake with TaskManager to synchronize thread creation | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||||
| std::unique_ptr<DataBuffer> fetched_buffer; | std::unique_ptr<DataBuffer> fetched_buffer; | ||||
| bool eof = false; | bool eof = false; | ||||
| while (!eof) { | while (!eof) { | ||||
| @@ -227,8 +228,17 @@ Status GeneratorOp::operator()() { | |||||
| MS_LOG(DEBUG) << "Generator operator main execution loop complete."; | MS_LOG(DEBUG) << "Generator operator main execution loop complete."; | ||||
| eof = true; | eof = true; | ||||
| } else { | } else { | ||||
| // Self-reset to start a new iteration | |||||
| RETURN_IF_NOT_OK(Reset()); | |||||
| // Waiting for repeatOp to start new epoch | |||||
| // If Reset() is called first by repeat op, this wait() will return right away. | |||||
| // If Reset() is not called yet, this wait() will block until reset. | |||||
| if (this->op_total_repeats() < 0) { | |||||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||||
| // Clear the status of the wait post | |||||
| wp_.Clear(); | |||||
| } else { | |||||
| // Self-reset to start a new iteration | |||||
| RETURN_IF_NOT_OK(Reset()); | |||||
| } | |||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | UpdateRepeatAndEpochCounter(); | ||||
| } | } | ||||
| @@ -240,6 +250,10 @@ Status GeneratorOp::Reset() { | |||||
| // Reset Op state | // Reset Op state | ||||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | MS_LOG(DEBUG) << Name() << " performing a self-reset."; | ||||
| RETURN_IF_NOT_OK(this->Init()); | RETURN_IF_NOT_OK(this->Init()); | ||||
| if (this->op_total_repeats() < 0) { | |||||
| // Wake up master thread | |||||
| wp_.Set(); | |||||
| } | |||||
| return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); | return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); | ||||
| } | } | ||||
| @@ -144,6 +144,8 @@ class GeneratorOp : public PipelineOp { | |||||
| py::object generator_; | py::object generator_; | ||||
| int32_t buffer_id_; | int32_t buffer_id_; | ||||
| WaitPost wp_; | |||||
| Status Init(); | Status Init(); | ||||
| void Dealloc() noexcept; | void Dealloc() noexcept; | ||||
| @@ -22,15 +22,31 @@ | |||||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | #include "minddata/dataset/engine/datasetops/device_queue_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/generator_op.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| RepeatPass::RepeatPass() | RepeatPass::RepeatPass() | ||||
| : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {} | |||||
| : is_repeated_(false), | |||||
| nested_repeats_(0), | |||||
| num_repeats_(1), | |||||
| num_epochs_(1), | |||||
| is_merge_(false), | |||||
| is_cached_(false), | |||||
| cache_lookup_(nullptr) {} | |||||
| // Identifies the subtree below this node as being in a repeated path of the tree. | // Identifies the subtree below this node as being in a repeated path of the tree. | ||||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | ||||
| // Create a new stack for eoe operators and push onto our stack of stacks. | |||||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||||
| eoe_op_stacks_.push(std::move(new_stack)); | |||||
| // If we are already repeated, then this is a nested repeat. | |||||
| if (is_repeated_) { | |||||
| nested_repeats_++; | |||||
| } | |||||
| is_repeated_ = true; | |||||
| // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. | // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. | ||||
| // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. | // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. | ||||
| if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { | if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { | ||||
| @@ -58,7 +74,9 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modifie | |||||
| // that RepeatOp does. However, epoch control is actually simpler because it can | // that RepeatOp does. However, epoch control is actually simpler because it can | ||||
| // only exist as the root node so it doesn't need all the nested code. | // only exist as the root node so it doesn't need all the nested code. | ||||
| // Create a new stack for eoe operators and push onto our stack of stacks. | // Create a new stack for eoe operators and push onto our stack of stacks. | ||||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||||
| eoe_op_stacks_.push(std::move(new_stack)); | |||||
| is_repeated_ = true; | |||||
| // Get the total number of epochs from the EpochCtrlOp parameter | // Get the total number of epochs from the EpochCtrlOp parameter | ||||
| num_epochs_ = node->num_repeats(); | num_epochs_ = node->num_repeats(); | ||||
| // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. | // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. | ||||
| @@ -85,6 +103,22 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| // Hooks up any identified eoe nodes under this repeat. | // Hooks up any identified eoe nodes under this repeat. | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | ||||
| // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking | |||||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||||
| while (leaf_op != nullptr) { | |||||
| node->AddToEoeList(leaf_op); | |||||
| leaf_op = PopFromEOEOpStack(); | |||||
| } | |||||
| // At this point, we are done with the save area stack. It's a unique pointer to an empty stack | |||||
| // at this time, so we can pop it to get rid of it. | |||||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| if (!current_stack->empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); | |||||
| } | |||||
| eoe_op_stacks_.pop(); | |||||
| // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up | // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up | ||||
| // and set its total repeats. It is important that the op is removed from the save area, | // and set its total repeats. It is important that the op is removed from the save area, | ||||
| // because the merge op above us may also take action on it later for a different case when | // because the merge op above us may also take action on it later for a different case when | ||||
| @@ -95,6 +129,18 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| cache_lookup_.reset(); | cache_lookup_.reset(); | ||||
| } | } | ||||
| // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. | |||||
| // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. | |||||
| if (nested_repeats_ > 0) { | |||||
| AddToEOEOpStack(node); | |||||
| nested_repeats_--; | |||||
| } else { | |||||
| // If we are not nested, or we were the top-most repeat, now we clear the flag | |||||
| if (nested_repeats_ != 0) { | |||||
| RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!"); | |||||
| } | |||||
| is_repeated_ = false; | |||||
| } | |||||
| if (is_cached_) { | if (is_cached_) { | ||||
| AddToCachedOpStack(node); | AddToCachedOpStack(node); | ||||
| } | } | ||||
| @@ -110,6 +156,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| // Hooks up any identified eoe nodes under this repeat. | // Hooks up any identified eoe nodes under this repeat. | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | ||||
| // Pop the leaf ops from the save-area stack and add them to the eoe node tracking | |||||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||||
| while (leaf_op != nullptr) { | |||||
| node->AddToEoeList(leaf_op); | |||||
| leaf_op = PopFromEOEOpStack(); | |||||
| } | |||||
| is_repeated_ = false; | |||||
| node->set_total_repeats(num_repeats_); | node->set_total_repeats(num_repeats_); | ||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | ||||
| // We finish the walk of this EpochCtrl's descendent nodes. | // We finish the walk of this EpochCtrl's descendent nodes. | ||||
| @@ -138,6 +191,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||||
| // If we are in a repeat path, then set our repeated flag | |||||
| if (is_repeated_) { | |||||
| // if infinite repeat save ourself in a stack for the repeat operator above us | |||||
| if (num_repeats_ < 0) { | |||||
| AddToEOEOpStack(node); | |||||
| } | |||||
| } | |||||
| // If we are under a cache op, then save ourselves to the cached op stack. | |||||
| if (is_cached_) { | |||||
| AddToCachedOpStack(node); | |||||
| } | |||||
| // Set total repeats and total epochs for the node | |||||
| node->set_total_repeats(num_repeats_); | |||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| return Status::OK(); | |||||
| } | |||||
| // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | ||||
| // for use with a controlling repeat above it. | // for use with a controlling repeat above it. | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | ||||
| @@ -190,6 +260,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Adds an operator to the eoe operator stack save area | |||||
| void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { | |||||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| current_stack->push(dataset_op); | |||||
| } | |||||
| // Pops an operator from the eoe operator stack save area | |||||
| std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() { | |||||
| std::shared_ptr<DatasetOp> top_op = nullptr; | |||||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| if (current_stack != nullptr && !current_stack->empty()) { | |||||
| top_op = current_stack->top(); | |||||
| current_stack->pop(); | |||||
| } | |||||
| return top_op; | |||||
| } | |||||
| // Adds an operator to the cached operator stack save area | // Adds an operator to the cached operator stack save area | ||||
| void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); } | void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); } | ||||
| @@ -98,6 +98,12 @@ class RepeatPass : public NodePass { | |||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | ||||
| /// \brief Special case for GeneratorOp | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||||
| /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | ||||
| /// for use with a controlling repeat above it. | /// for use with a controlling repeat above it. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| @@ -106,6 +112,19 @@ class RepeatPass : public NodePass { | |||||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | ||||
| private: | private: | ||||
| /// \brief Adds an operator to the eoe operator stack save area | |||||
| /// \param op - The dataset op to work add to eoe stack | |||||
| /// \return Status - The error code return | |||||
| void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||||
| /// \brief Pops an operator from the eoe operator stack save area | |||||
| /// \return shared_ptr to the popped operator | |||||
| std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | |||||
| bool is_repeated_; // T/F if we are processing under a repeat | |||||
| int32_t nested_repeats_; // A counter for nested repeats | |||||
| std::stack<std::unique_ptr<op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) | |||||
| /// \brief Adds an operator to the cached operator stack save area | /// \brief Adds an operator to the cached operator stack save area | ||||
| /// \param op - The dataset op to work add to cached stack | /// \param op - The dataset op to work add to cached stack | ||||
| /// \return Status - The error code return | /// \return Status - The error code return | ||||