Merge pull request !6598 from nsyca/revert-3974tags/v1.0.0
| @@ -218,6 +218,16 @@ Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t ba | |||
| return Status::OK(); | |||
| } | |||
| Status BucketBatchByLengthOp::Reset() { | |||
| batch_count_ = 0; | |||
| for (int i = 0; i < buckets_.size(); i++) { | |||
| buckets_[i] = std::make_unique<TensorQTable>(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Computing the assignment of the column name map and check compute input columns. | |||
| Status BucketBatchByLengthOp::ComputeColMap() { | |||
| RETURN_IF_NOT_OK(DatasetOp::ComputeColMap()); | |||
| @@ -126,6 +126,10 @@ class BucketBatchByLengthOp : public PipelineOp { | |||
| // @return Status - The error code returned | |||
| Status operator()() override; | |||
| // Function that is called by ResetOp at the end of every epoch | |||
| // @return Status - The error code returned | |||
| Status Reset() override; | |||
| private: | |||
| Status ObtainElementLength(int32_t *out_element_length, TensorRow element); | |||
| @@ -42,7 +42,8 @@ Status CacheBase::Reset() { | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| } | |||
| // Wake up the workers to get them going again in a new epoch | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| MS_LOG(DEBUG) << Name() << " resetting."; | |||
| epoch_sync_.Set(); | |||
| return Status::OK(); | |||
| } | |||
| CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | |||
| @@ -71,6 +72,7 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them | |||
| // to the WorkerEntry. | |||
| do { | |||
| epoch_sync_.Clear(); | |||
| if (AllowCacheMiss() && wait_cnt > 0) { | |||
| MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_ | |||
| << " Total number of rows : " << row_cnt_; | |||
| @@ -110,17 +112,11 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| // If repeat but the not last repeat, wait for reset. | |||
| if (!IsLastIteration()) { | |||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; | |||
| RETURN_IF_NOT_OK(epoch_sync_.Wait()); | |||
| } else { | |||
| // We can break out from the loop. | |||
| break; | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) RETURN_IF_NOT_OK(Reset()); | |||
| UpdateRepeatAndEpochCounter(); | |||
| } while (true); | |||
| // Flow the eof before exit | |||
| @@ -146,13 +142,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> blk; | |||
| do { | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); | |||
| if (blk->wait()) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (blk->eof()) { | |||
| if (blk->eof()) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))); | |||
| } else if (blk->eoe()) { | |||
| if (AllowCacheMiss()) { | |||
| @@ -196,7 +186,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| } | |||
| Status CacheBase::RegisterResources() { | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks())); | |||
| @@ -26,6 +26,7 @@ | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -87,6 +88,7 @@ class CacheBase : public ParallelOp { | |||
| int64_t row_cnt_; | |||
| std::atomic<int64_t> num_cache_miss_; | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| WaitPost epoch_sync_; | |||
| int32_t rows_per_buffer_; | |||
| Connector<std::vector<row_id_type>> keys_miss_; | |||
| QueueMap<row_id_type, TensorRow> prefetch_; | |||
| @@ -108,6 +110,7 @@ class CacheBase : public ParallelOp { | |||
| private: | |||
| constexpr static int32_t connector_capacity_ = 1024; | |||
| int32_t prefetch_size_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| QueueList<std::unique_ptr<IOBlock>> prefetch_queues_; | |||
| std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_; | |||
| @@ -434,7 +434,6 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) { | |||
| void DatasetOp::UpdateRepeatAndEpochCounter() { | |||
| op_current_repeats_++; | |||
| if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++; | |||
| MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -51,7 +51,15 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call the super class for displaying any common detailed info | |||
| PipelineOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_; | |||
| out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_ | |||
| << "\nLeaf Nodes in execution path:"; | |||
| if (!eoe_ops_.empty()) { | |||
| for (size_t i = 0; i < eoe_ops_.size(); i++) { | |||
| out << "\n Operator: " << eoe_ops_[i]->id(); | |||
| } | |||
| } else { | |||
| out << " None."; | |||
| } | |||
| out << "\n\n"; | |||
| } | |||
| } | |||
| @@ -86,6 +94,13 @@ Status EpochCtrlOp::EoeReceived(int32_t worker_id) { | |||
| // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. | |||
| state_ = OpState::kDeOpIdle; | |||
| if (repeat_count_ != num_repeats_) { | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id(); | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -123,6 +123,7 @@ Status FilterOp::WorkerEntry(int32_t worker_id) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); | |||
| if (in_buffer->eoe()) { | |||
| filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); | |||
| UpdateRepeatAndEpochCounter(); | |||
| continue; | |||
| } else if (in_buffer->eof()) { | |||
| filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); | |||
| @@ -199,7 +200,6 @@ Status FilterOp::Collector() { | |||
| RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); | |||
| if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || | |||
| in_pair.second == filterCtrl::kFilterEoe) { | |||
| if (in_pair.second == filterCtrl::kFilterEoe) UpdateRepeatAndEpochCounter(); | |||
| uint32_t out_task_id = out_id_cnt % num_workers_; | |||
| RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first))); | |||
| out_id_cnt++; | |||
| @@ -228,6 +228,12 @@ class MapOp : public ParallelOp { | |||
| // Indices of the columns to process. | |||
| std::vector<size_t> to_process_indices_; | |||
| // Wait post used to perform the pausing logic in MapOp | |||
| WaitPost wait_for_workers_post_; | |||
| // Count number of workers that have signaled master | |||
| std::atomic_int num_workers_paused_; | |||
| // Private function for worker/thread to loop continuously. It comprises the main | |||
| // logic of MapOp: getting the data from previous Op, validating user specified column names, | |||
| // applying a list of TensorOps to each of the data, process the results and then | |||
| @@ -31,9 +31,7 @@ ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shar | |||
| num_workers_(num_workers), | |||
| num_producers_(num_workers), | |||
| worker_connector_size_(1), | |||
| worker_connector_(nullptr), | |||
| num_workers_paused_(0), | |||
| epoch_sync_flag_(false) {} | |||
| worker_connector_(nullptr) {} | |||
| // Creates the internal worker connector for the parallel op if the derived class wants to use it | |||
| Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { | |||
| @@ -84,15 +82,5 @@ Status ParallelOp::RegisterWorkerConnectors() { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status ParallelOp::WaitForWorkers() { | |||
| num_workers_paused_ = 0; | |||
| for (int32_t i = 0; i < num_workers_; i++) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[i]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagWait))); | |||
| } | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Wait()); | |||
| wait_for_workers_post_.Clear(); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -21,7 +21,6 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -118,27 +117,10 @@ class ParallelOp : public DatasetOp { | |||
| // @return Status - The error code return | |||
| virtual Status WorkerEntry(int32_t workerId) = 0; | |||
| /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp | |||
| /// The expected behavior is this, when this function is invoked, this function will block until all the workers | |||
| /// have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master. | |||
| /// They would automatically wait on the QueueList when they are done. | |||
| /// \return Status | |||
| Status WaitForWorkers() override; | |||
| // Wait post used to perform the pausing logic | |||
| WaitPost wait_for_workers_post_; | |||
| // Count number of workers that have signaled master | |||
| std::atomic_int num_workers_paused_; | |||
| // Whether or not to sync worker threads at the end of each epoch | |||
| bool epoch_sync_flag_; | |||
| int32_t num_workers_; // The number of worker threads | |||
| int32_t num_producers_; // The number of threads pushing to the out_connector_ | |||
| int32_t worker_connector_size_; | |||
| std::unique_ptr<DbConnector> worker_connector_; // The internal connector for worker threads | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks | |||
| std::unique_ptr<DbConnector> worker_connector_; // The internal connector for worker threads | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -62,7 +62,15 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call the super class for displaying any common detailed info | |||
| PipelineOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_; | |||
| out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_ | |||
| << "\nLeaf Nodes in execution path:"; | |||
| if (!eoe_ops_.empty()) { | |||
| for (size_t i = 0; i < eoe_ops_.size(); i++) { | |||
| out << "\n Operator: " << eoe_ops_[i]->id(); | |||
| } | |||
| } else { | |||
| out << " None."; | |||
| } | |||
| out << "\n\n"; | |||
| } | |||
| } | |||
| @@ -100,6 +108,7 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo | |||
| // Base-class override for handling cases when an eoe is received. | |||
| Status RepeatOp::EoeReceived(int32_t worker_id) { | |||
| UpdateRepeatAndEpochCounter(); | |||
| repeat_count_++; | |||
| MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ | |||
| << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; | |||
| @@ -107,9 +116,15 @@ Status RepeatOp::EoeReceived(int32_t worker_id) { | |||
| if (repeat_count_ == num_repeats_) { | |||
| repeat_count_ = 0; | |||
| state_ = OpState::kDeOpIdle; | |||
| } else { | |||
| state_ = OpState::kDeOpRunning; | |||
| return Status::OK(); | |||
| } | |||
| // Invoke a reset against the eoe nodes only. | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id(); | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -138,6 +153,19 @@ int32_t RepeatOp::num_consumers() const { | |||
| } | |||
| } | |||
| // Drive reset actions if needed | |||
| Status RepeatOp::Reset() { | |||
| // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. | |||
| // In that case, we now have to bounce the reset down to our own eoe ops. | |||
| MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset."; | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id(); | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } | |||
| state_ = OpState::kDeOpRunning; | |||
| return Status::OK(); | |||
| } | |||
| int32_t RepeatOp::num_producers() const { | |||
| if (child_.empty() || child_[0] == nullptr) { | |||
| MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; | |||
| @@ -101,6 +101,10 @@ class RepeatOp : public PipelineOp { | |||
| // @param worker_id - The worker id | |||
| Status EofReceived(int32_t worker_id) override; | |||
| /// \brief reset Op | |||
| /// \@return Status - The error code return | |||
| Status Reset() override; | |||
| // Base-class override. Return the number of workers in the first parent. | |||
| // @param workerId - The worker id | |||
| int32_t num_consumers() const override; | |||
| @@ -129,6 +133,10 @@ class RepeatOp : public PipelineOp { | |||
| /// \return The number of repeats that the user requested | |||
| int32_t num_repeats() { return num_repeats_; } | |||
| // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes | |||
| // \param[in] eoe_op The input leaf/eoe operator to add to the list | |||
| void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } | |||
| protected: | |||
| // The number of repeats that the user requested. | |||
| // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. | |||
| @@ -139,6 +147,7 @@ class RepeatOp : public PipelineOp { | |||
| // Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class | |||
| // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats. | |||
| int32_t repeat_count_; | |||
| std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat. | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -165,19 +165,11 @@ Status AlbumOp::operator()() { | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| return Status::OK(); | |||
| } else { // not the last repeat. | |||
| } else { // not the last repeat. Sleep master thread, wait for the wake-up from reset | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -192,13 +184,7 @@ Status AlbumOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -486,9 +472,9 @@ void AlbumOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status AlbumOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -504,7 +490,7 @@ Status AlbumOp::LaunchThreadsAndInitOp() { | |||
| } | |||
| // registers QueueList and individual Queues for interrupt services | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| // launch main workers that load DataBuffers by reading all images | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&AlbumOp::WorkerEntry, this, std::placeholders::_1))); | |||
| TaskManager::FindMe()->Post(); | |||
| @@ -30,6 +30,7 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -288,7 +289,9 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||
| int64_t buf_cnt_; | |||
| int64_t sampler_ind_; | |||
| int64_t dirname_offset_; | |||
| WaitPost wp_; | |||
| std::vector<std::string> image_rows_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -94,7 +94,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() { | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this))); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); | |||
| @@ -310,19 +310,11 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) { | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| return Status::OK(); | |||
| } else { // not the last repeat. | |||
| } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -335,13 +327,7 @@ Status CelebAOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -423,8 +409,8 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status CelebAOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -229,6 +229,8 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_; | |||
| int64_t num_rows_in_attr_file_; // rows number specified in attr file | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| WaitPost wp_; | |||
| std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_; | |||
| std::string usage_; | |||
| std::ifstream partition_file_; | |||
| @@ -140,19 +140,11 @@ Status CifarOp::operator()() { | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| return Status::OK(); | |||
| } else { // not the last repeat. | |||
| } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -164,7 +156,7 @@ Status CifarOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->AllTasks()->CreateAsyncTask("Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this))); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1))); | |||
| @@ -183,13 +175,7 @@ Status CifarOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -257,9 +243,9 @@ void CifarOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status CifarOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -26,6 +26,7 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -232,10 +233,11 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t rows_per_buffer_; | |||
| std::string folder_path_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| int64_t row_cnt_; | |||
| int64_t buf_cnt_; | |||
| const std::string usage_; // can only be either "train" or "test" | |||
| WaitPost wp_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_; | |||
| std::vector<std::string> cifar_files_; | |||
| std::vector<std::pair<std::shared_ptr<Tensor>, std::vector<uint32_t>>> cifar_image_label_pairs_; | |||
| @@ -119,7 +119,6 @@ Status ClueOp::Init() { | |||
| } | |||
| Status ClueOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| load_jagged_connector_ = true; | |||
| load_io_block_queue_ = true; | |||
| @@ -275,8 +274,6 @@ Status ClueOp::operator()() { | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -25,6 +25,7 @@ | |||
| #include "minddata/dataset/util/auto_index.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -185,16 +185,8 @@ Status CocoOp::operator()() { | |||
| } else { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -216,9 +208,9 @@ void CocoOp::Print(std::ostream &out, bool show_all) const { | |||
| } | |||
| Status CocoOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -385,13 +377,7 @@ Status CocoOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -627,7 +613,7 @@ Status CocoOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1))); | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(this->ParseAnnotationIds()); | |||
| @@ -27,6 +27,7 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| @@ -326,8 +327,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| std::shared_ptr<Sampler> sampler_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| WaitPost wp_; | |||
| std::vector<std::string> image_ids_; | |||
| std::map<int32_t, std::string> image_index_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| std::vector<std::pair<std::string, std::vector<int32_t>>> label_index_; | |||
| std::map<std::string, CoordinateRow> coordinate_map_; | |||
| std::map<std::string, std::vector<uint32_t>> simple_item_map_; | |||
| @@ -479,7 +479,6 @@ Status CsvOp::CsvParser::InitCsvParser() { | |||
| } | |||
| Status CsvOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| load_jagged_connector_ = true; | |||
| load_io_block_queue_ = true; | |||
| @@ -576,8 +575,6 @@ Status CsvOp::operator()() { | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -186,6 +186,7 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) { | |||
| Status GeneratorOp::operator()() { | |||
| // Handshake with TaskManager to synchronize thread creation | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| std::unique_ptr<DataBuffer> fetched_buffer; | |||
| bool eof = false; | |||
| while (!eof) { | |||
| @@ -227,8 +228,12 @@ Status GeneratorOp::operator()() { | |||
| MS_LOG(DEBUG) << "Generator operator main execution loop complete."; | |||
| eof = true; | |||
| } else { | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| // Waiting for repeatOp to start new epoch | |||
| // If Reset() is called first by repeat op, this wait() will return right away. | |||
| // If Reset() is not called yet, this wait() will block until reset. | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| // Clear the status of the wait post | |||
| wp_.Clear(); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -238,8 +243,9 @@ Status GeneratorOp::operator()() { | |||
| Status GeneratorOp::Reset() { | |||
| // Reset Op state | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(this->Init()); | |||
| // Wake up master thread | |||
| wp_.Set(); | |||
| return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); | |||
| } | |||
| @@ -144,6 +144,8 @@ class GeneratorOp : public PipelineOp { | |||
| py::object generator_; | |||
| int32_t buffer_id_; | |||
| WaitPost wp_; | |||
| Status Init(); | |||
| void Dealloc() noexcept; | |||
| @@ -164,19 +164,11 @@ Status ImageFolderOp::operator()() { | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| return Status::OK(); | |||
| } else { // not the last repeat. | |||
| } else { // not the last repeat. Sleep master thread, wait for the wake-up from reset | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -191,13 +183,7 @@ Status ImageFolderOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -261,9 +247,9 @@ void ImageFolderOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status ImageFolderOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -379,7 +365,7 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() { | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| // The following code launch 3 threads group | |||
| // 1) A thread that walks all folders and push the folder names to a util:Queue mFoldernameQueue. | |||
| // 2) Workers that pull foldername from mFoldernameQueue, walk it and return the sorted images to mImagenameQueue | |||
| @@ -29,6 +29,7 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| @@ -262,7 +263,9 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| int64_t buf_cnt_; | |||
| int64_t sampler_ind_; | |||
| int64_t dirname_offset_; | |||
| WaitPost wp_; | |||
| std::vector<ImageLabelPair> image_label_pairs_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks | |||
| std::unique_ptr<Queue<std::string>> folder_name_queue_; | |||
| std::unique_ptr<Queue<FolderImagesPair>> image_name_queue_; | |||
| }; | |||
| @@ -33,9 +33,8 @@ class IOBlock { | |||
| public: | |||
| enum IOBlockFlags : uint32_t { | |||
| kDeIoBlockNone = 0, | |||
| kDeIoBlockFlagEoe = 1u, // end of IOBlocks for one epoch | |||
| kDeIoBlockFlagEof = 1u << 1, // end of IOBlocks for entire program | |||
| kDeIoBlockFlagWait = 1u << 2 // control signal for workers to suspend operations | |||
| kDeIoBlockFlagEoe = 1u, // end of IOBlocks for one epoch | |||
| kDeIoBlockFlagEof = 1u << 1 // end of IOBlocks for entire program | |||
| }; | |||
| // Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key. | |||
| @@ -74,10 +73,6 @@ class IOBlock { | |||
| // @return T/F if the IOBlock is eof | |||
| bool eof() const { return static_cast<uint32_t>(io_block_flags_) & static_cast<uint32_t>(kDeIoBlockFlagEof); } | |||
| // Does this block have the wait flag turned on? | |||
| // @return T/F is the IOBlock is wait | |||
| bool wait() const { return static_cast<uint32_t>(io_block_flags_) & static_cast<uint32_t>(kDeIoBlockFlagWait); } | |||
| // Adds a key to this block | |||
| // @param key - The key to add to this block | |||
| void AddKey(int64_t key) { index_keys_.push_back(key); } | |||
| @@ -127,16 +127,8 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) { | |||
| } else { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -148,7 +140,7 @@ Status ManifestOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&ManifestOp::WorkerEntry, this, std::placeholders::_1))); | |||
| @@ -167,13 +159,7 @@ Status ManifestOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -249,9 +235,9 @@ void ManifestOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status ManifestOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -26,6 +26,7 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -241,6 +242,8 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| std::string usage_; | |||
| int64_t buf_cnt_; | |||
| WaitPost wp_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| std::map<std::string, int32_t> label_index_; | |||
| std::vector<std::pair<std::string, std::vector<std::string>>> image_labelname_; | |||
| }; | |||
| @@ -129,9 +129,7 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf | |||
| num_padded_(num_padded), | |||
| sample_json_(sample_json), | |||
| sample_bytes_(sample_bytes) { | |||
| io_block_queues_.Init(num_workers_, op_connector_queue_size); | |||
| epoch_sync_flag_ = true; // MindRecordOp needs to turn this flag on, otherwise, calling ShuffleTask() before all | |||
| // tasks are consumed by the worker threads would cause problem. | |||
| io_blk_queues_.Init(num_workers_, op_connector_queue_size); | |||
| } | |||
| // Private helper method to encapsulate some common construction/reset tasks | |||
| @@ -221,27 +219,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { | |||
| Status MindRecordOp::WorkerEntry(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->wait()) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| continue; | |||
| } | |||
| if (io_block->eoe()) { | |||
| RETURN_IF_NOT_OK( | |||
| out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)))); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| continue; | |||
| } | |||
| if (io_block->eof()) { | |||
| RETURN_IF_NOT_OK( | |||
| out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)))); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| continue; | |||
| } | |||
| @@ -266,7 +255,7 @@ Status MindRecordOp::WorkerEntry(int32_t worker_id) { | |||
| } | |||
| RETURN_IF_NOT_OK(GetBufferFromReader(&fetched_buffer, buffer_id, worker_id)); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(fetched_buffer))); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| } | |||
| RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker."); | |||
| } | |||
| @@ -388,31 +377,27 @@ Status MindRecordOp::operator()() { | |||
| while (true) { // each iterator is 1 epoch | |||
| for (int32_t i = 0; i < buffers_needed_; ++i) { | |||
| std::vector<int64_t> keys(1, i); | |||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( | |||
| RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( | |||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | |||
| } | |||
| if (IsLastIteration()) { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | |||
| io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | |||
| for (int32_t i = 0; i < num_workers_; i++) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[i]->Add( | |||
| RETURN_IF_NOT_OK(io_blk_queues_[i]->Add( | |||
| std::move(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)))); | |||
| } | |||
| return Status::OK(); | |||
| } else { | |||
| } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| // reset our buffer count and go to loop again. | |||
| RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); | |||
| shard_reader_wait_post_.Clear(); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) RETURN_IF_NOT_OK(Reset()); | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| } | |||
| @@ -421,10 +406,10 @@ Status MindRecordOp::operator()() { | |||
| // info from it's previous execution and then initializes itself so that it can be executed | |||
| // again. | |||
| Status MindRecordOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); // Call our super class reset first. | |||
| shard_reader_->ShuffleTask(); | |||
| shard_reader_wait_post_.Set(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -434,8 +419,8 @@ Status MindRecordOp::LaunchThreadAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(io_blk_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(shard_reader_wait_post_.Register(tree_->AllTasks())); | |||
| if (shard_reader_->Launch(true) == MSRStatus::FAILED) { | |||
| RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); | |||
| } | |||
| @@ -29,6 +29,7 @@ | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/mindrecord/include/shard_column.h" | |||
| @@ -246,6 +247,8 @@ class MindRecordOp : public ParallelOp { | |||
| std::vector<int32_t> columns_blob_index_; // Blob Columns to load from dataset | |||
| std::unique_ptr<ShardReader> shard_reader_; | |||
| WaitPost shard_reader_wait_post_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_blk_queues_; | |||
| std::mutex ended_worker_mutex_; | |||
| }; | |||
| @@ -135,16 +135,8 @@ Status MnistOp::operator()() { | |||
| } else { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -158,13 +150,7 @@ Status MnistOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> iOBlock; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock)); | |||
| while (iOBlock != nullptr) { | |||
| if (iOBlock->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (iOBlock->eoe() == true) { | |||
| if (iOBlock->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (iOBlock->eof() == true) { | |||
| @@ -222,9 +208,9 @@ void MnistOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status MnistOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -414,7 +400,7 @@ Status MnistOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&MnistOp::WorkerEntry, this, std::placeholders::_1))); | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(this->WalkAllFiles()); | |||
| @@ -27,6 +27,7 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -244,6 +245,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| int64_t buf_cnt_; | |||
| int64_t row_cnt_; | |||
| WaitPost wp_; | |||
| std::string folder_path_; // directory of image folder | |||
| int32_t rows_per_buffer_; | |||
| const std::string usage_; // can only be either "train" or "test" | |||
| @@ -251,6 +253,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| std::vector<MnistLabelPair> image_label_pairs_; | |||
| std::vector<std::string> image_names_; | |||
| std::vector<std::string> label_names_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -239,15 +239,10 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { | |||
| } | |||
| } | |||
| // Wait for the reset to wake us up if we're not quitting | |||
| if (!(*quitting)) { | |||
| MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entering sync wait."; | |||
| if (last_guy_in) { | |||
| // If we are the last worker, do reset to wake other workers up | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } else { | |||
| // If we are not the last worker, wait for the reset | |||
| RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait()); | |||
| } | |||
| RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait()); | |||
| prev = guys_out_.fetch_add(1); | |||
| bool last_guy_out = (prev + 1) == num_workers_; | |||
| // Last guy out will clear the wait post and set the row counts | |||
| @@ -370,7 +365,7 @@ Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) { | |||
| // info from it's previous execution and then initializes itself so that it can be executed | |||
| // again. | |||
| Status RandomDataOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| MS_LOG(INFO) << "RandomDataOp resetting."; | |||
| // Ensure all guys are in the waitpost | |||
| if (guys_in_ != num_workers_) { | |||
| @@ -136,7 +136,6 @@ Status TextFileOp::Init() { | |||
| } | |||
| Status TextFileOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| load_jagged_connector_ = true; | |||
| load_io_block_queue_ = true; | |||
| @@ -433,8 +432,6 @@ Status TextFileOp::operator()() { | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -27,6 +27,7 @@ | |||
| #include "minddata/dataset/util/auto_index.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/wait_post.h" | |||
| #include "minddata/dataset/engine/jagged_connector.h" | |||
| @@ -316,8 +316,6 @@ Status TFReaderOp::operator()() { | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -710,7 +708,6 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table | |||
| // Overrides base class reset method. Cleans up any state info from it's previous execution and | |||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||
| Status TFReaderOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| // start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true | |||
| load_jagged_connector_ = true; | |||
| @@ -164,16 +164,8 @@ Status VOCOp::operator()() { | |||
| } else { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -195,9 +187,9 @@ void VOCOp::Print(std::ostream &out, bool show_all) const { | |||
| } | |||
| Status VOCOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -243,13 +235,7 @@ Status VOCOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -385,7 +371,7 @@ Status VOCOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&VOCOp::WorkerEntry, this, std::placeholders::_1))); | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(this->ParseImageIds()); | |||
| @@ -26,6 +26,7 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| @@ -282,7 +283,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t rows_per_buffer_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| WaitPost wp_; | |||
| std::vector<std::string> image_ids_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| std::map<std::string, int32_t> class_index_; | |||
| std::map<std::string, int32_t> label_index_; | |||
| std::map<std::string, Annotation> annotation_map_; | |||
| @@ -27,10 +27,25 @@ namespace mindspore { | |||
| namespace dataset { | |||
| RepeatPass::RepeatPass() | |||
| : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {} | |||
| : is_repeated_(false), | |||
| nested_repeats_(0), | |||
| num_repeats_(1), | |||
| num_epochs_(1), | |||
| is_merge_(false), | |||
| is_cached_(false), | |||
| cache_lookup_(nullptr) {} | |||
| // Identifies the subtree below this node as being in a repeated path of the tree. | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Create a new stack for eoe operators and push onto our stack of stacks. | |||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||
| eoe_op_stacks_.push(std::move(new_stack)); | |||
| // If we are already repeated, then this is a nested repeat. | |||
| if (is_repeated_) { | |||
| nested_repeats_++; | |||
| } | |||
| is_repeated_ = true; | |||
| // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. | |||
| // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. | |||
| if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { | |||
| @@ -58,7 +73,9 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modifie | |||
| // that RepeatOp does. However, epoch control is actually simpler because it can | |||
| // only exist as the root node so it doesn't need all the nested code. | |||
| // Create a new stack for eoe operators and push onto our stack of stacks. | |||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||
| eoe_op_stacks_.push(std::move(new_stack)); | |||
| is_repeated_ = true; | |||
| // Get the total number of epochs from the EpochCtrlOp parameter | |||
| num_epochs_ = node->num_repeats(); | |||
| // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. | |||
| @@ -85,16 +102,44 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // Hooks up any identified eoe nodes under this repeat. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking | |||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| node->AddToEoeList(leaf_op); | |||
| leaf_op = PopFromEOEOpStack(); | |||
| } | |||
| // At this point, we are done with the save area stack. It's a unique pointer to an empty stack | |||
| // at this time, so we can pop it to get rid of it. | |||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||
| if (!current_stack->empty()) { | |||
| RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); | |||
| } | |||
| eoe_op_stacks_.pop(); | |||
| // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up | |||
| // and set its total repeats. It is important that the op is removed from the save area, | |||
| // because the merge op above us may also take action on it later for a different case when | |||
| // there is no repeat in the merge leg. | |||
| // and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed | |||
| // from the save area, because the merge op above us may also take action on it later for a different | |||
| // case when there is no repeat in the merge leg. | |||
| if (is_merge_ && cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| cache_lookup_.reset(); | |||
| node->AddToEoeList(std::move(cache_lookup_)); | |||
| } | |||
| // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. | |||
| // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. | |||
| if (nested_repeats_ > 0) { | |||
| AddToEOEOpStack(node); | |||
| nested_repeats_--; | |||
| } else { | |||
| // If we are not nested, or we were the top-most repeat, now we clear the flag | |||
| if (nested_repeats_ != 0) { | |||
| RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!"); | |||
| } | |||
| is_repeated_ = false; | |||
| } | |||
| if (is_cached_) { | |||
| AddToCachedOpStack(node); | |||
| } | |||
| @@ -110,6 +155,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Hooks up any identified eoe nodes under this repeat. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| // Pop the leaf ops from the save-area stack and add them to the eoe node tracking | |||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| node->AddToEoeList(leaf_op); | |||
| leaf_op = PopFromEOEOpStack(); | |||
| } | |||
| is_repeated_ = false; | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| // We finish the walk of this EpochCtrl's descendent nodes. | |||
| @@ -120,17 +172,31 @@ Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) | |||
| // CacheOp removes previous leaf ops and replaces them with itself | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| is_cached_ = false; | |||
| if (is_repeated_) { | |||
| // if we are a cache within a repeat path of the tree, then there will be | |||
| // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the | |||
| // repeat or epoch ctrl operators can work with them for repeat activity during runtime. | |||
| // However, since a cache is present: | |||
| // - unflag those ops as being repeated ops | |||
| // - remove them from the eoe op stack so that repeat op above in the tree won't know about them | |||
| // - add ourself (the cache op), as an eoe op | |||
| // We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead | |||
| // the repeating behaviours shall be invoked against the cache op. | |||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| leaf_op = PopFromEOEOpStack(); | |||
| } | |||
| AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node)); | |||
| // if we are a cache within a repeat path of the tree, then adjust the total repeats and total epochs for cached ops. | |||
| // So that those cached nodes become 1-time use (up to eoe), never repeated. Instead | |||
| // the repeating behaviours shall be invoked against the cache op. | |||
| std::shared_ptr<DatasetOp> cached_op = PopFromCachedOpStack(); | |||
| while (cached_op != nullptr) { | |||
| int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; | |||
| cached_op->set_total_repeats(cached_op_total_repeats); | |||
| // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 | |||
| cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); | |||
| cached_op = PopFromCachedOpStack(); | |||
| // adjust the total epochs and total repeats for ops under this cache op | |||
| std::shared_ptr<DatasetOp> cached_op = PopFromCachedOpStack(); | |||
| while (cached_op != nullptr) { | |||
| int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; | |||
| cached_op->set_total_repeats(cached_op_total_repeats); | |||
| // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 | |||
| cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); | |||
| cached_op = PopFromCachedOpStack(); | |||
| } | |||
| } | |||
| node->set_total_repeats(num_repeats_); | |||
| @@ -141,7 +207,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | |||
| // for use with a controlling repeat above it. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| // If we are under a cache op, then save ourselves to the cached op stack. | |||
| // If we are in a repeat path, then set our repeated flag | |||
| if (is_repeated_) { | |||
| // if we are a leaf node then save ourself in a stack for the repeat operator above us | |||
| if (node->IsLeaf()) { | |||
| AddToEOEOpStack(node); | |||
| } | |||
| } | |||
| if (is_cached_) { | |||
| AddToCachedOpStack(node); | |||
| } | |||
| @@ -153,11 +225,15 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| // Turns off the tracking for operations under merge op | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | |||
| // would not have been consumed yet. In that case, we need to set its total repeats for it. | |||
| if (cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| // Setting the flag is needed since we didn't call the base class DatasetOp version | |||
| if (is_repeated_) { | |||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | |||
| // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack | |||
| if (cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| AddToEOEOpStack(std::move(cache_lookup_)); | |||
| } | |||
| } | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| @@ -190,6 +266,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified | |||
| return Status::OK(); | |||
| } | |||
| // Adds an operator to the eoe operator stack save area | |||
| void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { | |||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||
| current_stack->push(dataset_op); | |||
| } | |||
| // Pops an operator from the eoe operator stack save area | |||
| std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() { | |||
| std::shared_ptr<DatasetOp> top_op = nullptr; | |||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||
| if (current_stack != nullptr && !current_stack->empty()) { | |||
| top_op = current_stack->top(); | |||
| current_stack->pop(); | |||
| } | |||
| return top_op; | |||
| } | |||
| // Adds an operator to the cached operator stack save area | |||
| void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); } | |||
| @@ -106,6 +106,15 @@ class RepeatPass : public NodePass { | |||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | |||
| private: | |||
| /// \brief Adds an operator to the eoe operator stack save area | |||
| /// \param op - The dataset op to work add to eoe stack | |||
| /// \return Status - The error code return | |||
| void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||
| /// \brief Pops an operator from the eoe operator stack save area | |||
| /// \return shared_ptr to the popped operator | |||
| std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | |||
| /// \brief Adds an operator to the cached operator stack save area | |||
| /// \param op - The dataset op to work add to cached stack | |||
| /// \return Status - The error code return | |||
| @@ -115,12 +124,15 @@ class RepeatPass : public NodePass { | |||
| /// \return shared_ptr to the popped operator | |||
| std::shared_ptr<DatasetOp> PopFromCachedOpStack(); | |||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||
| bool is_cached_; // T/F is we are processing under a cache op | |||
| int32_t num_repeats_; // A multiplier to the total number of repeats | |||
| int32_t num_epochs_; // To save the total number of epochs | |||
| op_stack cached_op_stacks_; // A save area for ops under a cache op | |||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||
| bool is_repeated_; // T/F if we are processing under a repeat | |||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||
| bool is_cached_; // T/F is we are processing under a cache op | |||
| int32_t nested_repeats_; // A counter for nested repeats | |||
| int32_t num_repeats_; // A multiplier to the total number of repeats | |||
| int32_t num_epochs_; // To save the total number of epochs | |||
| std::stack<std::unique_ptr<op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) | |||
| op_stack cached_op_stacks_; // A save area for ops under a cache op | |||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||