| @@ -91,13 +91,14 @@ Status CacheBase::FetchSamplesToWorkers() { | |||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| // If repeat but the not last repeat, wait for reset. | // If repeat but the not last repeat, wait for reset. | ||||
| if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (!IsLastIteration()) { | |||||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; | MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; | ||||
| RETURN_IF_NOT_OK(epoch_sync_.Wait()); | RETURN_IF_NOT_OK(epoch_sync_.Wait()); | ||||
| } else { | } else { | ||||
| // We can break out from the loop. | // We can break out from the loop. | ||||
| break; | break; | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } while (true); | } while (true); | ||||
| // Flow the eof before exit | // Flow the eof before exit | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| @@ -294,7 +294,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) { | |||||
| Status CacheMergeOp::EoeReceived(int32_t worker_id) { | Status CacheMergeOp::EoeReceived(int32_t worker_id) { | ||||
| // If we are in a repeat path, send the eoe up. | // If we are in a repeat path, send the eoe up. | ||||
| // Otherwise ignore it. | // Otherwise ignore it. | ||||
| if (BitTest(op_ctrl_flags_, kDeOpRepeated)) { | |||||
| if (op_total_repeats_ > 1) { | |||||
| return DatasetOp::EoeReceived(worker_id); | return DatasetOp::EoeReceived(worker_id); | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -306,7 +306,7 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) { | |||||
| // getting an eoe. However, the logic demands that all epochs close with an eoe first before eof. | // getting an eoe. However, the logic demands that all epochs close with an eoe first before eof. | ||||
| // Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class | // Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class | ||||
| // provides that for us. | // provides that for us. | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) { | |||||
| if (op_total_repeats_ == 1) { | |||||
| MS_LOG(DEBUG) << "Cache merge sending eoe"; | MS_LOG(DEBUG) << "Cache merge sending eoe"; | ||||
| RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id)); | RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id)); | ||||
| } | } | ||||
| @@ -85,6 +85,10 @@ Status CacheOp::operator()() { | |||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| // Wait for the workers to finish caching the rows. | // Wait for the workers to finish caching the rows. | ||||
| RETURN_IF_NOT_OK(WaitForCachingAllRows()); | RETURN_IF_NOT_OK(WaitForCachingAllRows()); | ||||
| // Current repeats and current epochs may have increased when caching all rows with DatasetOp::GetNextInput. | |||||
| // But they shouldn't be increased because now cache op is starting to act as a leaf and its epoch hasn't started. | |||||
| op_current_repeats_ = 0; | |||||
| op_current_epochs_ = 0; | |||||
| RETURN_IF_NOT_OK(FetchSamplesToWorkers()); | RETURN_IF_NOT_OK(FetchSamplesToWorkers()); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -87,6 +87,7 @@ Status ConcatOp::operator()() { | |||||
| auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_, | CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_, | ||||
| "Something went wrong, eof count does not match the number of children."); | "Something went wrong, eof count does not match the number of children."); | ||||
| @@ -42,7 +42,10 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler | |||||
| operator_id_(kInvalidOperatorId), | operator_id_(kInvalidOperatorId), | ||||
| tree_(nullptr), | tree_(nullptr), | ||||
| state_(OpState::kDeOpIdle), | state_(OpState::kDeOpIdle), | ||||
| op_ctrl_flags_(kDeOpNone), | |||||
| op_total_repeats_(kInfiniteRepeat), | |||||
| op_num_repeats_per_epoch_(kInfiniteRepeat), | |||||
| op_current_repeats_(0), | |||||
| op_current_epochs_(0), | |||||
| out_connector_(nullptr) { | out_connector_(nullptr) { | ||||
| // The operator starts out with an invalid operator id. The only way to | // The operator starts out with an invalid operator id. The only way to | ||||
| // get it out of invalid state is to assign the operator to an execution tree. | // get it out of invalid state is to assign the operator to an execution tree. | ||||
| @@ -234,8 +237,8 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { | |||||
| for (size_t i = 0; i < parent_.size(); i++) { | for (size_t i = 0; i < parent_.size(); i++) { | ||||
| out << "\n Parent[" << i << "] id: " << parent_[i]->id(); | out << "\n Parent[" << i << "] id: " << parent_[i]->id(); | ||||
| } | } | ||||
| out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex | |||||
| << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' '); | |||||
| out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_ | |||||
| << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_; | |||||
| if (sampler_) { | if (sampler_) { | ||||
| sampler_->Print(out, show_all); | sampler_->Print(out, show_all); | ||||
| } | } | ||||
| @@ -264,6 +267,7 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo | |||||
| RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id)); | RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id)); | ||||
| // Loop until non EOE is received | // Loop until non EOE is received | ||||
| while (buf->eoe()) { | while (buf->eoe()) { | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| RETURN_IF_NOT_OK(EoeReceived(worker_id)); | RETURN_IF_NOT_OK(EoeReceived(worker_id)); | ||||
| if (state_ == OpState::kDeOpIdle) { | if (state_ == OpState::kDeOpIdle) { | ||||
| *p_buffer = std::move(buf); | *p_buffer = std::move(buf); | ||||
| @@ -407,5 +411,10 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) { | |||||
| uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length()); | uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length()); | ||||
| return cache_crc; | return cache_crc; | ||||
| } | } | ||||
| void DatasetOp::UpdateRepeatAndEpochCounter() { | |||||
| op_current_repeats_++; | |||||
| if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++; | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -70,13 +70,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| public: | public: | ||||
| static constexpr int32_t kInvalidOperatorId = -1; | static constexpr int32_t kInvalidOperatorId = -1; | ||||
| // Operator control flags | |||||
| enum OpControlFlags { | |||||
| kDeOpNone = 0, | |||||
| kDeOpRepeated = 1, // Operator is a node in a repeat path | |||||
| kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop | |||||
| }; | |||||
| static constexpr int32_t kInfiniteRepeat = -1; | |||||
| // Flags that control operator runtime behaviours | // Flags that control operator runtime behaviours | ||||
| enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated }; | enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated }; | ||||
| @@ -238,13 +232,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// \return T/F if this is an inlined operator | /// \return T/F if this is an inlined operator | ||||
| bool inlined() const { return (oc_queue_size_ == 0); } | bool inlined() const { return (oc_queue_size_ == 0); } | ||||
| /// \brief Setter function | |||||
| /// \return Sets the control flags | |||||
| void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); } | |||||
| /// \brief Setter function, set the number of total repeats for the operator | |||||
| void set_total_repeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; } | |||||
| /// \brief Setter function, set the number of repeats per epoch for the operator | |||||
| void set_num_repeats_per_epoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; } | |||||
| /// \brief Setter function | |||||
| /// \return Sets the control flags | |||||
| void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); } | |||||
| /// \brief Getter function | |||||
| /// \return The number of required repeats for the operator | |||||
| int32_t op_total_repeats() { return op_total_repeats_; } | |||||
| /// \brief Getter function | |||||
| /// \return The number of required epochs for the operator | |||||
| int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; } | |||||
| /// \brief Getter function | |||||
| /// \return The number of repeats per epoch for the operator | |||||
| int32_t op_num_repeats_per_epoch() { return op_num_repeats_per_epoch_; } | |||||
| /// \brief Register the internal worker connectors. No op unless it is a parallel op | /// \brief Register the internal worker connectors. No op unless it is a parallel op | ||||
| /// \return Status | /// \return Status | ||||
| @@ -350,6 +354,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// \return boolean returns true if it's a leaf | /// \return boolean returns true if it's a leaf | ||||
| bool IsLeaf() { return (child_.empty()); } | bool IsLeaf() { return (child_.empty()); } | ||||
| /// Checks if an operator has reached its last iteration | |||||
| /// \return boolean returns true if it's last iteration | |||||
| bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; } | |||||
| protected: | protected: | ||||
| /// \brief Removes a parent operator from this operator | /// \brief Removes a parent operator from this operator | ||||
| /// \notes External callers do not have access to this function | /// \notes External callers do not have access to this function | ||||
| @@ -368,6 +376,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// \return - Status | /// \return - Status | ||||
| virtual Status ComputeColMap(); | virtual Status ComputeColMap(); | ||||
| /// Increase op_current_repeats_ by 1 when one repeat finished. | |||||
| /// If this repeat happen to be the last repeat in the current epoch, also increase op_current_epochs_ by 1. | |||||
| void UpdateRepeatAndEpochCounter(); | |||||
| std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes | std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes | ||||
| std::vector<DatasetOp *> parent_; // Parent nodes. No ownership | std::vector<DatasetOp *> parent_; // Parent nodes. No ownership | ||||
| std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler | std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler | ||||
| @@ -375,7 +387,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| int32_t operator_id_; // Generated id for the node | int32_t operator_id_; // Generated id for the node | ||||
| ExecutionTree *tree_; // Back pointer to our tree. | ExecutionTree *tree_; // Back pointer to our tree. | ||||
| OpState state_; // The state of the operator, Running, Idle, Terminated | OpState state_; // The state of the operator, Running, Idle, Terminated | ||||
| uint32_t op_ctrl_flags_; // Flags for the operator | |||||
| int32_t op_total_repeats_; // Required number of repeats for the operator | |||||
| int32_t op_num_repeats_per_epoch_; // Total number of repeats per epoch for the operator | |||||
| int32_t op_current_repeats_; // Current number of repeats the operator has handled | |||||
| int32_t op_current_epochs_; // Current number of epochs the operator has handled | |||||
| std::unique_ptr<DbConnector> out_connector_; // Output Connector | std::unique_ptr<DbConnector> out_connector_; // Output Connector | ||||
| std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name | std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name | ||||
| std::mutex column_name_map_mutex_; // For protecting shared access to the column map | std::mutex column_name_map_mutex_; // For protecting shared access to the column map | ||||
| @@ -30,7 +30,7 @@ namespace dataset { | |||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) { | Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) { | ||||
| RETURN_IF_NOT_OK(SanityCheck()); | RETURN_IF_NOT_OK(SanityCheck()); | ||||
| *ptr = std::make_shared<EpochCtrlOp>(build_max_repeats_); | |||||
| *ptr = std::make_shared<EpochCtrlOp>(build_num_repeats_); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -48,12 +48,12 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Call the super class for displaying any common 1-liner info | // Call the super class for displaying any common 1-liner info | ||||
| PipelineOp::Print(out, show_all); | PipelineOp::Print(out, show_all); | ||||
| // Then show any custom derived-internal 1-liner info for this op | // Then show any custom derived-internal 1-liner info for this op | ||||
| out << " [epochs: " << max_repeats_ << "]\n"; | |||||
| out << " [epochs: " << num_repeats_ << "]\n"; | |||||
| } else { | } else { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| PipelineOp::Print(out, show_all); | PipelineOp::Print(out, show_all); | ||||
| // Then show any custom derived-internal stuff | // Then show any custom derived-internal stuff | ||||
| out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_ | |||||
| out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_ | |||||
| << "\nLeaf Nodes in execution path:"; | << "\nLeaf Nodes in execution path:"; | ||||
| if (!eoe_ops_.empty()) { | if (!eoe_ops_.empty()) { | ||||
| for (size_t i = 0; i < eoe_ops_.size(); i++) { | for (size_t i = 0; i < eoe_ops_.size(); i++) { | ||||
| @@ -88,24 +88,15 @@ Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t | |||||
| } | } | ||||
| Status EpochCtrlOp::EoeReceived(int32_t worker_id) { | Status EpochCtrlOp::EoeReceived(int32_t worker_id) { | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| repeat_count_++; | repeat_count_++; | ||||
| MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_ | MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_ | ||||
| << ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_; | |||||
| // If we've reached the requested epoch count, then flag the leaf nodes | |||||
| // to tell them they've got one more epoch to perform. When they reach the end | |||||
| // of the last epoch, they quit rather than loop again. | |||||
| if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) { | |||||
| for (auto &eoe_op : eoe_ops_) { | |||||
| MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << eoe_op->id(); | |||||
| eoe_op->set_control_flag(kDeOpLastRepeat); | |||||
| } | |||||
| } | |||||
| << ". Max epochs: " << num_repeats_; | |||||
| // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. | // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. | ||||
| state_ = OpState::kDeOpIdle; | state_ = OpState::kDeOpIdle; | ||||
| if (repeat_count_ != max_repeats_) { | |||||
| if (repeat_count_ != num_repeats_) { | |||||
| for (auto &eoe_op : eoe_ops_) { | for (auto &eoe_op : eoe_ops_) { | ||||
| MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id(); | MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id(); | ||||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | RETURN_IF_NOT_OK(eoe_op->Reset()); | ||||
| @@ -119,6 +119,7 @@ Status FilterOp::WorkerEntry(int32_t worker_id) { | |||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); | RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); | ||||
| if (in_buffer->eoe()) { | if (in_buffer->eoe()) { | ||||
| filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); | filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| continue; | continue; | ||||
| } else if (in_buffer->eof()) { | } else if (in_buffer->eof()) { | ||||
| filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); | filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); | ||||
| @@ -233,6 +233,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) { | |||||
| // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work | // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work | ||||
| // with Performance Mode design. | // with Performance Mode design. | ||||
| if (in_buffer->eoe()) { | if (in_buffer->eoe()) { | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| // Calling base class EoeReceived to forward eoe buffer. | // Calling base class EoeReceived to forward eoe buffer. | ||||
| RETURN_IF_NOT_OK(EoeReceived(worker_id)); | RETURN_IF_NOT_OK(EoeReceived(worker_id)); | ||||
| // Fetch next data buffer and map job list | // Fetch next data buffer and map job list | ||||
| @@ -76,6 +76,9 @@ Status ProjectOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t w | |||||
| if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) { | if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) { | ||||
| RETURN_IF_NOT_OK(Project(p_buffer)); | RETURN_IF_NOT_OK(Project(p_buffer)); | ||||
| } | } | ||||
| if ((*p_buffer)->eoe()) { | |||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -28,10 +28,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Builder constructor. Creates the builder object. | // Builder constructor. Creates the builder object. | ||||
| RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {} | |||||
| RepeatOp::Builder::Builder(int32_t count) : build_num_repeats_(count) {} | |||||
| Status RepeatOp::Builder::SanityCheck() const { | Status RepeatOp::Builder::SanityCheck() const { | ||||
| if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) { | |||||
| if (build_num_repeats_ < kInfiniteRepeat || build_num_repeats_ == 0) { | |||||
| std::string err_msg("Repeat count must be > 0 or -1."); | std::string err_msg("Repeat count must be > 0 or -1."); | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| @@ -41,12 +41,12 @@ Status RepeatOp::Builder::SanityCheck() const { | |||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| Status RepeatOp::Builder::Build(std::shared_ptr<RepeatOp> *ptr) { | Status RepeatOp::Builder::Build(std::shared_ptr<RepeatOp> *ptr) { | ||||
| RETURN_IF_NOT_OK(SanityCheck()); | RETURN_IF_NOT_OK(SanityCheck()); | ||||
| *ptr = std::make_shared<RepeatOp>(build_max_repeats_); | |||||
| *ptr = std::make_shared<RepeatOp>(build_num_repeats_); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Constructor of the RepeatOp. | // Constructor of the RepeatOp. | ||||
| RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {} | |||||
| RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), num_repeats_(count), repeat_count_(0) {} | |||||
| // Destructor | // Destructor | ||||
| RepeatOp::~RepeatOp() {} | RepeatOp::~RepeatOp() {} | ||||
| @@ -59,12 +59,12 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Call the super class for displaying any common 1-liner info | // Call the super class for displaying any common 1-liner info | ||||
| PipelineOp::Print(out, show_all); | PipelineOp::Print(out, show_all); | ||||
| // Then show any custom derived-internal 1-liner info for this op | // Then show any custom derived-internal 1-liner info for this op | ||||
| out << " [repeats: " << max_repeats_ << "]\n"; | |||||
| out << " [repeats: " << num_repeats_ << "]\n"; | |||||
| } else { | } else { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| PipelineOp::Print(out, show_all); | PipelineOp::Print(out, show_all); | ||||
| // Then show any custom derived-internal stuff | // Then show any custom derived-internal stuff | ||||
| out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_ | |||||
| out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_ | |||||
| << "\nLeaf Nodes in execution path:"; | << "\nLeaf Nodes in execution path:"; | ||||
| if (!eoe_ops_.empty()) { | if (!eoe_ops_.empty()) { | ||||
| for (size_t i = 0; i < eoe_ops_.size(); i++) { | for (size_t i = 0; i < eoe_ops_.size(); i++) { | ||||
| @@ -109,22 +109,13 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo | |||||
| // Base-class override for handling cases when an eoe is received. | // Base-class override for handling cases when an eoe is received. | ||||
| Status RepeatOp::EoeReceived(int32_t worker_id) { | Status RepeatOp::EoeReceived(int32_t worker_id) { | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| repeat_count_++; | repeat_count_++; | ||||
| MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ | MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ | ||||
| << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; | << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; | ||||
| bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated); | |||||
| bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat); | |||||
| // If we've reached the requested repeat count, then flag the eoe nodes | |||||
| // to tell them they've got one more epoch to perform. When they reach the end | |||||
| // of the last epoch, they quit rather than loop again. This happens in two cases: | |||||
| // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or, | |||||
| // 2- We are not repeated | |||||
| if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) { | |||||
| for (auto &eoe_op : eoe_ops_) { | |||||
| eoe_op->set_control_flag(kDeOpLastRepeat); | |||||
| } | |||||
| } | |||||
| if (repeat_count_ == max_repeats_) { | |||||
| if (repeat_count_ == num_repeats_) { | |||||
| repeat_count_ = 0; | repeat_count_ = 0; | ||||
| state_ = OpState::kDeOpIdle; | state_ = OpState::kDeOpIdle; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -26,8 +26,6 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| class RepeatOp : public PipelineOp { | class RepeatOp : public PipelineOp { | ||||
| public: | public: | ||||
| static constexpr int32_t kInfiniteRepeat = -1; | |||||
| // The nested builder class inside of the RepeatOp is used to help manage all of the arguments | // The nested builder class inside of the RepeatOp is used to help manage all of the arguments | ||||
| // for constructing it. This repeat op is very simple though, so this builder is really just | // for constructing it. This repeat op is very simple though, so this builder is really just | ||||
| // provided for a consistent look and feel for creators of Dataset operators overall. | // provided for a consistent look and feel for creators of Dataset operators overall. | ||||
| @@ -47,7 +45,7 @@ class RepeatOp : public PipelineOp { | |||||
| Status Build(std::shared_ptr<RepeatOp> *); | Status Build(std::shared_ptr<RepeatOp> *); | ||||
| protected: | protected: | ||||
| int32_t build_max_repeats_; | |||||
| int32_t build_num_repeats_; | |||||
| Status SanityCheck() const; | Status SanityCheck() const; | ||||
| }; | }; | ||||
| @@ -131,13 +129,24 @@ class RepeatOp : public PipelineOp { | |||||
| // @return Name of the current Op | // @return Name of the current Op | ||||
| std::string Name() const override { return kRepeatOp; } | std::string Name() const override { return kRepeatOp; } | ||||
| /// \brief Getter function | |||||
| /// \return The number of repeats that the user requested | |||||
| int32_t num_repeats() { return num_repeats_; } | |||||
| // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes | // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes | ||||
| // \param[in] eoe_op The input leaf/eoe operator to add to the list | // \param[in] eoe_op The input leaf/eoe operator to add to the list | ||||
| void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } | void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } | ||||
| protected: | protected: | ||||
| int32_t max_repeats_; // The number of repeats that the user requested | |||||
| int32_t repeat_count_; // A counter for the current number of executed repeats | |||||
| // The number of repeats that the user requested. | |||||
| // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. | |||||
| // For example, for repeat1 op in pipeline tfreader -> repeat1(3) -> repeat2(2) -> epoch ctrl(4), | |||||
| // num_repeats_ = 3, op_total_repeats_ = 24, op_num_repeats_per_epoch_ = 6. | |||||
| int32_t num_repeats_; | |||||
| // A counter for the current number of executed repeats. | |||||
| // Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class | |||||
| // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats. | |||||
| int32_t repeat_count_; | |||||
| std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat. | std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat. | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -293,7 +293,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add( | ||||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | ||||
| } | } | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| @@ -310,6 +310,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) { | |||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -120,7 +120,7 @@ Status CifarOp::operator()() { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | ||||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | ||||
| } | } | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| @@ -137,6 +137,7 @@ Status CifarOp::operator()() { | |||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -271,13 +271,14 @@ Status ClueOp::operator()() { | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| finished_reading_dataset_ = true; | finished_reading_dataset_ = true; | ||||
| NotifyToFillIOBlockQueue(); | NotifyToFillIOBlockQueue(); | ||||
| } else { | } else { | ||||
| jagged_buffer_connector_->DoReset(); | jagged_buffer_connector_->DoReset(); | ||||
| buffer_id = 0; | buffer_id = 0; | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | ||||
| @@ -167,7 +167,7 @@ Status CocoOp::operator()() { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | ||||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | ||||
| } | } | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe); | std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe); | ||||
| std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof); | std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof); | ||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); | ||||
| @@ -184,6 +184,7 @@ Status CocoOp::operator()() { | |||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -472,13 +472,14 @@ Status CsvOp::operator()() { | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| finished_reading_dataset_ = true; | finished_reading_dataset_ = true; | ||||
| NotifyToFillIOBlockQueue(); | NotifyToFillIOBlockQueue(); | ||||
| } else { | } else { | ||||
| jagged_buffer_connector_->DoReset(); | jagged_buffer_connector_->DoReset(); | ||||
| buffer_id = 0; | buffer_id = 0; | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | ||||
| @@ -218,7 +218,7 @@ Status GeneratorOp::operator()() { | |||||
| MS_LOG(DEBUG) << "Generator operator sends out EOE."; | MS_LOG(DEBUG) << "Generator operator sends out EOE."; | ||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| // If last repeat or not repeated, push out EOF and exit master loop | // If last repeat or not repeated, push out EOF and exit master loop | ||||
| MS_LOG(DEBUG) << "Generator operator sends out EOF."; | MS_LOG(DEBUG) << "Generator operator sends out EOF."; | ||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | ||||
| @@ -233,6 +233,7 @@ Status GeneratorOp::operator()() { | |||||
| // Clear the status of the wait post | // Clear the status of the wait post | ||||
| wp_.Clear(); | wp_.Clear(); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -151,7 +151,7 @@ Status ImageFolderOp::operator()() { | |||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(keys, IOBlock::kDeIoBlockNone))); | io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(keys, IOBlock::kDeIoBlockNone))); | ||||
| } | } | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe); | std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe); | ||||
| std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof); | std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof); | ||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); | ||||
| @@ -168,6 +168,7 @@ Status ImageFolderOp::operator()() { | |||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -112,7 +112,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | ||||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | ||||
| } | } | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| @@ -129,6 +129,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) { | |||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -380,7 +380,7 @@ Status MindRecordOp::operator()() { | |||||
| RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( | RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( | ||||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | ||||
| } | } | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| @@ -398,6 +398,7 @@ Status MindRecordOp::operator()() { | |||||
| RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); | RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); | ||||
| shard_reader_wait_post_.Clear(); | shard_reader_wait_post_.Clear(); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -111,7 +111,7 @@ Status MnistOp::operator()() { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | ||||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | ||||
| } | } | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| @@ -128,6 +128,7 @@ Status MnistOp::operator()() { | |||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -221,7 +221,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { | |||||
| all_out_.Wait(); | all_out_.Wait(); | ||||
| // If we are not in a repeat loop, or that was the last repeat already, then setup our exit | // If we are not in a repeat loop, or that was the last repeat already, then setup our exit | ||||
| // condition from the master loop. | // condition from the master loop. | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| *quitting = true; | *quitting = true; | ||||
| } | } | ||||
| @@ -231,6 +231,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { | |||||
| if (last_guy_in) { | if (last_guy_in) { | ||||
| MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker " | MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker " | ||||
| << eoe_worker_id_; | << eoe_worker_id_; | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| // Prepare for sync | // Prepare for sync | ||||
| all_out_.Clear(); | all_out_.Clear(); | ||||
| // Always flow eoe at the end | // Always flow eoe at the end | ||||
| @@ -421,13 +421,14 @@ Status TextFileOp::operator()() { | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| finished_reading_dataset_ = true; | finished_reading_dataset_ = true; | ||||
| NotifyToFillIOBlockQueue(); | NotifyToFillIOBlockQueue(); | ||||
| } else { | } else { | ||||
| jagged_buffer_connector_->DoReset(); | jagged_buffer_connector_->DoReset(); | ||||
| buffer_id = 0; | buffer_id = 0; | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | ||||
| @@ -310,13 +310,14 @@ Status TFReaderOp::operator()() { | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| finished_reading_dataset_ = true; | finished_reading_dataset_ = true; | ||||
| NotifyToFillIOBlockQueue(); | NotifyToFillIOBlockQueue(); | ||||
| } else { | } else { | ||||
| jagged_buffer_connector_->DoReset(); | jagged_buffer_connector_->DoReset(); | ||||
| buffer_id = 0; | buffer_id = 0; | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | ||||
| @@ -145,7 +145,7 @@ Status VOCOp::operator()() { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | ||||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | ||||
| } | } | ||||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||||
| if (IsLastIteration()) { | |||||
| std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe); | std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe); | ||||
| std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof); | std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof); | ||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); | ||||
| @@ -162,6 +162,7 @@ Status VOCOp::operator()() { | |||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | ||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -84,6 +84,7 @@ Status TakeOp::operator()() { | |||||
| // Loop until non EOE is received | // Loop until non EOE is received | ||||
| if (buf->eoe()) { | if (buf->eoe()) { | ||||
| UpdateRepeatAndEpochCounter(); | |||||
| take_count_ = 0; | take_count_ = 0; | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); | RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); | ||||
| @@ -25,18 +25,44 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {} | |||||
| RepeatPass::RepeatPass() | |||||
| : is_repeated_(false), | |||||
| nested_repeats_(0), | |||||
| num_repeats_(1), | |||||
| num_epochs_(1), | |||||
| is_merge_(false), | |||||
| is_cached_(false), | |||||
| cache_lookup_(nullptr) {} | |||||
| // Identifies the subtree below this node as being in a repeated path of the tree. | // Identifies the subtree below this node as being in a repeated path of the tree. | ||||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | ||||
| // Create a new stack for eoe operators and push onto our stack of stacks. | // Create a new stack for eoe operators and push onto our stack of stacks. | ||||
| std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>(); | |||||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||||
| eoe_op_stacks_.push(std::move(new_stack)); | eoe_op_stacks_.push(std::move(new_stack)); | ||||
| // If we are already repeated, then this is a nested repeat. | // If we are already repeated, then this is a nested repeat. | ||||
| if (is_repeated_) { | if (is_repeated_) { | ||||
| nested_repeats_++; | nested_repeats_++; | ||||
| } | } | ||||
| is_repeated_ = true; | is_repeated_ = true; | ||||
| // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. | |||||
| // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. | |||||
| if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { | |||||
| num_repeats_ = -num_repeats_; | |||||
| } | |||||
| // This RepeatOp and its descendent nodes should be repeated for another num_repeats() times. | |||||
| // | |||||
| // Consider this example: | |||||
| // tfreader --> map --> repeat(2) --> epoch ctrl(3) | |||||
| // num_repeats_ is originally 3, after this repeat(2), num_repeats_ becomes 6 (2*3), | |||||
| // meaning repeat op should be set to read 6 times (2*3), do does map op and tfreader op. | |||||
| // | |||||
| // Another example: | |||||
| // tfreader --> repeat1(3) --> map --> repeat2(2) --> epoch ctrl(4) | |||||
| // num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4), | |||||
| // meaning repeat2 and map op should be set to read 8 times (2*4). | |||||
| // Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times. | |||||
| num_repeats_ *= node->num_repeats(); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -46,9 +72,16 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modifie | |||||
| // that RepeatOp does. However, epoch control is actually simpler because it can | // that RepeatOp does. However, epoch control is actually simpler because it can | ||||
| // only exist as the root node so it doesn't need all the nested code. | // only exist as the root node so it doesn't need all the nested code. | ||||
| // Create a new stack for eoe operators and push onto our stack of stacks. | // Create a new stack for eoe operators and push onto our stack of stacks. | ||||
| std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>(); | |||||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||||
| eoe_op_stacks_.push(std::move(new_stack)); | eoe_op_stacks_.push(std::move(new_stack)); | ||||
| is_repeated_ = true; | is_repeated_ = true; | ||||
| // Get the total number of epochs from the EpochCtrlOp parameter | |||||
| num_epochs_ = node->num_repeats(); | |||||
| // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. | |||||
| // For example: tfreader --> epoch ctrl(3) | |||||
| // num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3), | |||||
| // meaning epoch ctrl op should be set to read 3 times (1*3), so does tfreader op. | |||||
| num_repeats_ *= num_epochs_; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -59,6 +92,13 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modifi | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Identifies the subtree below this node as being cached | |||||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| // Turn on the flag that we're under a merge op | |||||
| is_cached_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| // Hooks up any identified eoe nodes under this repeat. | // Hooks up any identified eoe nodes under this repeat. | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | ||||
| // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking | // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking | ||||
| @@ -71,7 +111,7 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| // At this point, we are done with the save area stack. It's a unique pointer to an empty stack | // At this point, we are done with the save area stack. It's a unique pointer to an empty stack | ||||
| // at this time, so we can pop it to get rid of it. | // at this time, so we can pop it to get rid of it. | ||||
| eoe_op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| if (!current_stack->empty()) { | if (!current_stack->empty()) { | ||||
| RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); | RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); | ||||
| } | } | ||||
| @@ -82,14 +122,14 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| // from the save area, because the merge op above us may also take action on it later for a different | // from the save area, because the merge op above us may also take action on it later for a different | ||||
| // case when there is no repeat in the merge leg. | // case when there is no repeat in the merge leg. | ||||
| if (is_merge_ && cache_lookup_) { | if (is_merge_ && cache_lookup_) { | ||||
| cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated); | |||||
| cache_lookup_->set_total_repeats(num_repeats_); | |||||
| cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| node->AddToEoeList(std::move(cache_lookup_)); | node->AddToEoeList(std::move(cache_lookup_)); | ||||
| } | } | ||||
| // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. | // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. | ||||
| // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. | // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. | ||||
| if (nested_repeats_ > 0) { | if (nested_repeats_ > 0) { | ||||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||||
| AddToEOEOpStack(node); | AddToEOEOpStack(node); | ||||
| nested_repeats_--; | nested_repeats_--; | ||||
| } else { | } else { | ||||
| @@ -99,7 +139,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| } | } | ||||
| is_repeated_ = false; | is_repeated_ = false; | ||||
| } | } | ||||
| if (is_cached_) { | |||||
| AddToCachedOpStack(node); | |||||
| } | |||||
| node->set_total_repeats(num_repeats_); | |||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| // We finish the walk of this RepeatOp's descendent nodes. | |||||
| // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n. | |||||
| // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode, | |||||
| // so we devide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp. | |||||
| num_repeats_ /= node->num_repeats(); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -112,13 +161,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) | |||||
| leaf_op = PopFromEOEOpStack(); | leaf_op = PopFromEOEOpStack(); | ||||
| } | } | ||||
| is_repeated_ = false; | is_repeated_ = false; | ||||
| node->set_total_repeats(num_repeats_); | |||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| // We finish the walk of this EpochCtrl's descendent nodes. | |||||
| num_repeats_ /= node->num_repeats(); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // CacheOp removes previous leaf ops and replaces them with itself | // CacheOp removes previous leaf ops and replaces them with itself | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | ||||
| is_cached_ = false; | |||||
| if (is_repeated_) { | if (is_repeated_) { | ||||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||||
| // if we are a cache within a repeat path of the tree, then there will be | // if we are a cache within a repeat path of the tree, then there will be | ||||
| // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the | // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the | ||||
| // repeat or epoch ctrl operators can work with them for repeat activity during runtime. | // repeat or epoch ctrl operators can work with them for repeat activity during runtime. | ||||
| @@ -130,13 +183,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| // the repeating behaviours shall be invoked against the cache op. | // the repeating behaviours shall be invoked against the cache op. | ||||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | ||||
| while (leaf_op != nullptr) { | while (leaf_op != nullptr) { | ||||
| leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat); | |||||
| leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated); | |||||
| leaf_op = PopFromEOEOpStack(); | leaf_op = PopFromEOEOpStack(); | ||||
| } | } | ||||
| AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node)); | AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node)); | ||||
| // adjust the total epochs and total repeats for ops under this cache op | |||||
| std::shared_ptr<DatasetOp> cached_op = PopFromCachedOpStack(); | |||||
| while (cached_op != nullptr) { | |||||
| int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; | |||||
| cached_op->set_total_repeats(cached_op_total_repeats); | |||||
| // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 | |||||
| cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); | |||||
| cached_op = PopFromCachedOpStack(); | |||||
| } | |||||
| } | } | ||||
| node->set_total_repeats(num_repeats_); | |||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -145,13 +208,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | ||||
| // If we are in a repeat path, then set our repeated flag | // If we are in a repeat path, then set our repeated flag | ||||
| if (is_repeated_) { | if (is_repeated_) { | ||||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||||
| // if we are a leaf node then save ourself in a stack for the repeat operator above us | // if we are a leaf node then save ourself in a stack for the repeat operator above us | ||||
| if (node->IsLeaf()) { | if (node->IsLeaf()) { | ||||
| AddToEOEOpStack(node); | AddToEOEOpStack(node); | ||||
| } | } | ||||
| } | } | ||||
| if (is_cached_) { | |||||
| AddToCachedOpStack(node); | |||||
| } | |||||
| // Set total repeats and total epochs for the node | |||||
| node->set_total_repeats(num_repeats_); | |||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -159,13 +226,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | ||||
| // Setting the flag is needed since we didn't call the base class DatasetOp version | // Setting the flag is needed since we didn't call the base class DatasetOp version | ||||
| if (is_repeated_) { | if (is_repeated_) { | ||||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | // If there was not any repeat in the merge cache miss leg, then the cache_lookup | ||||
| // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack | // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack | ||||
| if (cache_lookup_) { | if (cache_lookup_) { | ||||
| cache_lookup_->set_total_repeats(num_repeats_); | |||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| AddToEOEOpStack(std::move(cache_lookup_)); | AddToEOEOpStack(std::move(cache_lookup_)); | ||||
| } | } | ||||
| } | } | ||||
| node->set_total_repeats(num_repeats_); | |||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used | cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used | ||||
| is_merge_ = false; | is_merge_ = false; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -178,13 +248,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified | |||||
| RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); | RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); | ||||
| } | } | ||||
| // If we are in a repeat path already, then there must be a repeat above the merge op | |||||
| // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. | |||||
| if (is_repeated_) { | |||||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||||
| // Delay the assigment of this leap to the eoe stack and allow the merge op processing to handle that. | |||||
| } | |||||
| // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we | // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we | ||||
| // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself | // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself | ||||
| // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. | // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. | ||||
| @@ -197,19 +260,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified | |||||
| // Adds an operator to the eoe operator stack save area | // Adds an operator to the eoe operator stack save area | ||||
| void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { | void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { | ||||
| eoe_op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| current_stack->push(dataset_op); | current_stack->push(dataset_op); | ||||
| } | } | ||||
| // Pops an operator from the eoe operator stack save area | // Pops an operator from the eoe operator stack save area | ||||
| std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() { | std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() { | ||||
| std::shared_ptr<DatasetOp> top_op = nullptr; | std::shared_ptr<DatasetOp> top_op = nullptr; | ||||
| eoe_op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| if (current_stack != nullptr && !current_stack->empty()) { | if (current_stack != nullptr && !current_stack->empty()) { | ||||
| top_op = current_stack->top(); | top_op = current_stack->top(); | ||||
| current_stack->pop(); | current_stack->pop(); | ||||
| } | } | ||||
| return top_op; | return top_op; | ||||
| } | } | ||||
| // Adds an operator to the cached operator stack save area | |||||
| void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); } | |||||
| // Pops an operator from the cached operator stack save area | |||||
| std::shared_ptr<DatasetOp> RepeatPass::PopFromCachedOpStack() { | |||||
| std::shared_ptr<DatasetOp> top_op = nullptr; | |||||
| if (!cached_op_stacks_.empty()) { | |||||
| top_op = cached_op_stacks_.top(); | |||||
| cached_op_stacks_.pop(); | |||||
| } | |||||
| return top_op; | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,7 +30,7 @@ namespace dataset { | |||||
| /// to the eoe-producing (typically leaf) nodes underneath it. | /// to the eoe-producing (typically leaf) nodes underneath it. | ||||
| class RepeatPass : public NodePass { | class RepeatPass : public NodePass { | ||||
| public: | public: | ||||
| using eoe_op_stack = std::stack<std::shared_ptr<DatasetOp>>; | |||||
| using op_stack = std::stack<std::shared_ptr<DatasetOp>>; | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| RepeatPass(); | RepeatPass(); | ||||
| @@ -56,6 +56,12 @@ class RepeatPass : public NodePass { | |||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | ||||
| /// \brief Identifies the subtree below this node as being cached | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Hooks up any identified eoe nodes under this repeat. | /// \brief Hooks up any identified eoe nodes under this repeat. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| @@ -103,11 +109,24 @@ class RepeatPass : public NodePass { | |||||
| /// \return shared_ptr to the popped operator | /// \return shared_ptr to the popped operator | ||||
| std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | ||||
| bool is_repeated_; // T/F if we are processing under a repeat | |||||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||||
| int32_t nested_repeats_; // A counter for nested repeats | |||||
| std::stack<std::unique_ptr<eoe_op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) | |||||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||||
| /// \brief Adds an operator to the cached operator stack save area | |||||
| /// \param op - The dataset op to work add to cached stack | |||||
| /// \return Status - The error code return | |||||
| void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||||
| /// \brief Pops an operator from the cached operator stack save area | |||||
| /// \return shared_ptr to the popped operator | |||||
| std::shared_ptr<DatasetOp> PopFromCachedOpStack(); | |||||
| bool is_repeated_; // T/F if we are processing under a repeat | |||||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||||
| bool is_cached_; // T/F is we are processing under a cache op | |||||
| int32_t nested_repeats_; // A counter for nested repeats | |||||
| int32_t num_repeats_; // A multiplier to the total number of repeats | |||||
| int32_t num_epochs_; // To save the total number of epochs | |||||
| std::stack<std::unique_ptr<op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) | |||||
| op_stack cached_op_stacks_; // A save area for ops under a cache op | |||||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -565,6 +565,99 @@ def test_generator_tuple_repeat_repeat_3(): | |||||
| # rely on garbage collector to destroy iter1 | # rely on garbage collector to destroy iter1 | ||||
| def test_generator_tuple_infinite_repeat_repeat_1(): | |||||
| """ | |||||
| test generator tuple infinite repeat repeat 1 | |||||
| """ | |||||
| logger.info("Test 1D Generator : 0 - 63") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||||
| data1 = data1.repeat() | |||||
| data1 = data1.repeat(3) | |||||
| iter1 = data1.create_tuple_iterator(num_epochs=11) | |||||
| i = 0 | |||||
| for item in iter1: # each data is a dictionary | |||||
| golden = np.array([i % 64]) | |||||
| np.testing.assert_array_equal(item[0], golden) | |||||
| i = i + 1 | |||||
| if i == 100: | |||||
| break | |||||
| # rely on garbage collector to destroy iter1 | |||||
| def test_generator_tuple_infinite_repeat_repeat_2(): | |||||
| """ | |||||
| test generator tuple infinite repeat repeat 2 | |||||
| """ | |||||
| logger.info("Test 1D Generator : 0 - 63") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||||
| data1 = data1.repeat(3) | |||||
| data1 = data1.repeat() | |||||
| iter1 = data1.create_tuple_iterator(num_epochs=11) | |||||
| i = 0 | |||||
| for item in iter1: # each data is a dictionary | |||||
| golden = np.array([i % 64]) | |||||
| np.testing.assert_array_equal(item[0], golden) | |||||
| i = i + 1 | |||||
| if i == 100: | |||||
| break | |||||
| # rely on garbage collector to destroy iter1 | |||||
| def test_generator_tuple_infinite_repeat_repeat_3(): | |||||
| """ | |||||
| test generator tuple infinite repeat repeat 3 | |||||
| """ | |||||
| logger.info("Test 1D Generator : 0 - 63") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||||
| data1 = data1.repeat() | |||||
| data1 = data1.repeat() | |||||
| iter1 = data1.create_tuple_iterator(num_epochs=11) | |||||
| i = 0 | |||||
| for item in iter1: # each data is a dictionary | |||||
| golden = np.array([i % 64]) | |||||
| np.testing.assert_array_equal(item[0], golden) | |||||
| i = i + 1 | |||||
| if i == 100: | |||||
| break | |||||
| # rely on garbage collector to destroy iter1 | |||||
| def test_generator_tuple_infinite_repeat_repeat_4(): | |||||
| """ | |||||
| test generator tuple infinite repeat repeat 4 | |||||
| """ | |||||
| logger.info("Test 1D Generator : 0 - 63") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||||
| data1 = data1.repeat() | |||||
| data1 = data1.repeat() | |||||
| iter1 = data1.create_tuple_iterator() | |||||
| i = 0 | |||||
| for item in iter1: # each data is a dictionary | |||||
| golden = np.array([i % 64]) | |||||
| np.testing.assert_array_equal(item[0], golden) | |||||
| i = i + 1 | |||||
| if i == 100: | |||||
| break | |||||
| # rely on garbage collector to destroy iter1 | |||||
| def test_generator_reusedataset(): | def test_generator_reusedataset(): | ||||
| """ | """ | ||||
| test generator reusedataset | test generator reusedataset | ||||