Merge pull request !3974 from lixiachen/repeat_task2tags/v1.0.0
| @@ -218,16 +218,6 @@ Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t ba | |||
| return Status::OK(); | |||
| } | |||
| Status BucketBatchByLengthOp::Reset() { | |||
| batch_count_ = 0; | |||
| for (int i = 0; i < buckets_.size(); i++) { | |||
| buckets_[i] = std::make_unique<TensorQTable>(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Computing the assignment of the column name map and check compute input columns. | |||
| Status BucketBatchByLengthOp::ComputeColMap() { | |||
| RETURN_IF_NOT_OK(DatasetOp::ComputeColMap()); | |||
| @@ -126,10 +126,6 @@ class BucketBatchByLengthOp : public PipelineOp { | |||
| // @return Status - The error code returned | |||
| Status operator()() override; | |||
| // Function that is called by ResetOp at the end of every epoch | |||
| // @return Status - The error code returned | |||
| Status Reset() override; | |||
| private: | |||
| Status ObtainElementLength(int32_t *out_element_length, TensorRow element); | |||
| @@ -42,8 +42,7 @@ Status CacheBase::Reset() { | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| } | |||
| // Wake up the workers to get them going again in a new epoch | |||
| MS_LOG(DEBUG) << Name() << " resetting."; | |||
| epoch_sync_.Set(); | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| return Status::OK(); | |||
| } | |||
| CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | |||
| @@ -72,7 +71,6 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them | |||
| // to the WorkerEntry. | |||
| do { | |||
| epoch_sync_.Clear(); | |||
| if (AllowCacheMiss() && wait_cnt > 0) { | |||
| MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_ | |||
| << " Total number of rows : " << row_cnt_; | |||
| @@ -112,11 +110,17 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| // If repeat but the not last repeat, wait for reset. | |||
| if (!IsLastIteration()) { | |||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; | |||
| RETURN_IF_NOT_OK(epoch_sync_.Wait()); | |||
| } else { | |||
| // We can break out from the loop. | |||
| break; | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) RETURN_IF_NOT_OK(Reset()); | |||
| UpdateRepeatAndEpochCounter(); | |||
| } while (true); | |||
| // Flow the eof before exit | |||
| @@ -142,7 +146,13 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> blk; | |||
| do { | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); | |||
| if (blk->eof()) { | |||
| if (blk->wait()) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (blk->eof()) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))); | |||
| } else if (blk->eoe()) { | |||
| if (AllowCacheMiss()) { | |||
| @@ -186,7 +196,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| } | |||
| Status CacheBase::RegisterResources() { | |||
| RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks())); | |||
| @@ -26,7 +26,6 @@ | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -88,7 +87,6 @@ class CacheBase : public ParallelOp { | |||
| int64_t row_cnt_; | |||
| std::atomic<int64_t> num_cache_miss_; | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| WaitPost epoch_sync_; | |||
| int32_t rows_per_buffer_; | |||
| Connector<std::vector<row_id_type>> keys_miss_; | |||
| QueueMap<row_id_type, TensorRow> prefetch_; | |||
| @@ -110,7 +108,6 @@ class CacheBase : public ParallelOp { | |||
| private: | |||
| constexpr static int32_t connector_capacity_ = 1024; | |||
| int32_t prefetch_size_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| QueueList<std::unique_ptr<IOBlock>> prefetch_queues_; | |||
| std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_; | |||
| @@ -434,6 +434,7 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) { | |||
| void DatasetOp::UpdateRepeatAndEpochCounter() { | |||
| op_current_repeats_++; | |||
| if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++; | |||
| MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -51,15 +51,7 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call the super class for displaying any common detailed info | |||
| PipelineOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_ | |||
| << "\nLeaf Nodes in execution path:"; | |||
| if (!eoe_ops_.empty()) { | |||
| for (size_t i = 0; i < eoe_ops_.size(); i++) { | |||
| out << "\n Operator: " << eoe_ops_[i]->id(); | |||
| } | |||
| } else { | |||
| out << " None."; | |||
| } | |||
| out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_; | |||
| out << "\n\n"; | |||
| } | |||
| } | |||
| @@ -94,13 +86,6 @@ Status EpochCtrlOp::EoeReceived(int32_t worker_id) { | |||
| // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. | |||
| state_ = OpState::kDeOpIdle; | |||
| if (repeat_count_ != num_repeats_) { | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id(); | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -123,7 +123,6 @@ Status FilterOp::WorkerEntry(int32_t worker_id) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); | |||
| if (in_buffer->eoe()) { | |||
| filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); | |||
| UpdateRepeatAndEpochCounter(); | |||
| continue; | |||
| } else if (in_buffer->eof()) { | |||
| filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); | |||
| @@ -200,6 +199,7 @@ Status FilterOp::Collector() { | |||
| RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); | |||
| if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || | |||
| in_pair.second == filterCtrl::kFilterEoe) { | |||
| if (in_pair.second == filterCtrl::kFilterEoe) UpdateRepeatAndEpochCounter(); | |||
| uint32_t out_task_id = out_id_cnt % num_workers_; | |||
| RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first))); | |||
| out_id_cnt++; | |||
| @@ -228,12 +228,6 @@ class MapOp : public ParallelOp { | |||
| // Indices of the columns to process. | |||
| std::vector<size_t> to_process_indices_; | |||
| // Wait post used to perform the pausing logic in MapOp | |||
| WaitPost wait_for_workers_post_; | |||
| // Count number of workers that have signaled master | |||
| std::atomic_int num_workers_paused_; | |||
| // Private function for worker/thread to loop continuously. It comprises the main | |||
| // logic of MapOp: getting the data from previous Op, validating user specified column names, | |||
| // applying a list of TensorOps to each of the data, process the results and then | |||
| @@ -31,7 +31,9 @@ ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shar | |||
| num_workers_(num_workers), | |||
| num_producers_(num_workers), | |||
| worker_connector_size_(1), | |||
| worker_connector_(nullptr) {} | |||
| worker_connector_(nullptr), | |||
| num_workers_paused_(0), | |||
| epoch_sync_flag_(false) {} | |||
| // Creates the internal worker connector for the parallel op if the derived class wants to use it | |||
| Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { | |||
| @@ -82,5 +84,15 @@ Status ParallelOp::RegisterWorkerConnectors() { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status ParallelOp::WaitForWorkers() { | |||
| num_workers_paused_ = 0; | |||
| for (int32_t i = 0; i < num_workers_; i++) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[i]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagWait))); | |||
| } | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Wait()); | |||
| wait_for_workers_post_.Clear(); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -21,6 +21,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -117,10 +118,27 @@ class ParallelOp : public DatasetOp { | |||
| // @return Status - The error code return | |||
| virtual Status WorkerEntry(int32_t workerId) = 0; | |||
| /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp | |||
| /// The expected behavior is this, when this function is invoked, this function will block until all the workers | |||
| /// have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master. | |||
| /// They would automatically wait on the QueueList when they are done. | |||
| /// \return Status | |||
| Status WaitForWorkers() override; | |||
| // Wait post used to perform the pausing logic | |||
| WaitPost wait_for_workers_post_; | |||
| // Count number of workers that have signaled master | |||
| std::atomic_int num_workers_paused_; | |||
| // Whether or not to sync worker threads at the end of each epoch | |||
| bool epoch_sync_flag_; | |||
| int32_t num_workers_; // The number of worker threads | |||
| int32_t num_producers_; // The number of threads pushing to the out_connector_ | |||
| int32_t worker_connector_size_; | |||
| std::unique_ptr<DbConnector> worker_connector_; // The internal connector for worker threads | |||
| std::unique_ptr<DbConnector> worker_connector_; // The internal connector for worker threads | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -62,15 +62,7 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call the super class for displaying any common detailed info | |||
| PipelineOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_ | |||
| << "\nLeaf Nodes in execution path:"; | |||
| if (!eoe_ops_.empty()) { | |||
| for (size_t i = 0; i < eoe_ops_.size(); i++) { | |||
| out << "\n Operator: " << eoe_ops_[i]->id(); | |||
| } | |||
| } else { | |||
| out << " None."; | |||
| } | |||
| out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_; | |||
| out << "\n\n"; | |||
| } | |||
| } | |||
| @@ -108,7 +100,6 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo | |||
| // Base-class override for handling cases when an eoe is received. | |||
| Status RepeatOp::EoeReceived(int32_t worker_id) { | |||
| UpdateRepeatAndEpochCounter(); | |||
| repeat_count_++; | |||
| MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ | |||
| << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; | |||
| @@ -116,15 +107,9 @@ Status RepeatOp::EoeReceived(int32_t worker_id) { | |||
| if (repeat_count_ == num_repeats_) { | |||
| repeat_count_ = 0; | |||
| state_ = OpState::kDeOpIdle; | |||
| return Status::OK(); | |||
| } | |||
| // Invoke a reset against the eoe nodes only. | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id(); | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } else { | |||
| state_ = OpState::kDeOpRunning; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -153,19 +138,6 @@ int32_t RepeatOp::num_consumers() const { | |||
| } | |||
| } | |||
| // Drive reset actions if needed | |||
| Status RepeatOp::Reset() { | |||
| // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. | |||
| // In that case, we now have to bounce the reset down to our own eoe ops. | |||
| MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset."; | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id(); | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } | |||
| state_ = OpState::kDeOpRunning; | |||
| return Status::OK(); | |||
| } | |||
| int32_t RepeatOp::num_producers() const { | |||
| if (child_.empty() || child_[0] == nullptr) { | |||
| MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; | |||
| @@ -101,10 +101,6 @@ class RepeatOp : public PipelineOp { | |||
| // @param worker_id - The worker id | |||
| Status EofReceived(int32_t worker_id) override; | |||
| /// \brief reset Op | |||
| /// \@return Status - The error code return | |||
| Status Reset() override; | |||
| // Base-class override. Return the number of workers in the first parent. | |||
| // @param workerId - The worker id | |||
| int32_t num_consumers() const override; | |||
| @@ -133,10 +129,6 @@ class RepeatOp : public PipelineOp { | |||
| /// \return The number of repeats that the user requested | |||
| int32_t num_repeats() { return num_repeats_; } | |||
| // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes | |||
| // \param[in] eoe_op The input leaf/eoe operator to add to the list | |||
| void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } | |||
| protected: | |||
| // The number of repeats that the user requested. | |||
| // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. | |||
| @@ -147,7 +139,6 @@ class RepeatOp : public PipelineOp { | |||
| // Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class | |||
| // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats. | |||
| int32_t repeat_count_; | |||
| std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat. | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -161,11 +161,19 @@ Status AlbumOp::operator()() { | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| return Status::OK(); | |||
| } else { // not the last repeat. Sleep master thread, wait for the wake-up from reset | |||
| } else { // not the last repeat. | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -180,7 +188,13 @@ Status AlbumOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->eoe() == true) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -468,9 +482,9 @@ void AlbumOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status AlbumOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -486,7 +500,7 @@ Status AlbumOp::LaunchThreadsAndInitOp() { | |||
| } | |||
| // registers QueueList and individual Queues for interrupt services | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| // launch main workers that load DataBuffers by reading all images | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&AlbumOp::WorkerEntry, this, std::placeholders::_1))); | |||
| TaskManager::FindMe()->Post(); | |||
| @@ -30,7 +30,6 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -289,9 +288,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||
| int64_t buf_cnt_; | |||
| int64_t sampler_ind_; | |||
| int64_t dirname_offset_; | |||
| WaitPost wp_; | |||
| std::vector<std::string> image_rows_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -94,7 +94,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() { | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this))); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); | |||
| @@ -311,11 +311,19 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) { | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| return Status::OK(); | |||
| } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset | |||
| } else { // not the last repeat. | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -328,7 +336,13 @@ Status CelebAOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->eoe() == true) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -410,8 +424,8 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status CelebAOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -229,8 +229,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_; | |||
| int64_t num_rows_in_attr_file_; // rows number specified in attr file | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| WaitPost wp_; | |||
| std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_; | |||
| std::string usage_; | |||
| std::ifstream partition_file_; | |||
| @@ -140,11 +140,19 @@ Status CifarOp::operator()() { | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| return Status::OK(); | |||
| } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset | |||
| } else { // not the last repeat. | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -156,7 +164,7 @@ Status CifarOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->AllTasks()->CreateAsyncTask("Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this))); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1))); | |||
| @@ -175,7 +183,13 @@ Status CifarOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->eoe() == true) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -243,9 +257,9 @@ void CifarOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status CifarOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -26,7 +26,6 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -233,11 +232,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t rows_per_buffer_; | |||
| std::string folder_path_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| int64_t row_cnt_; | |||
| int64_t buf_cnt_; | |||
| const std::string usage_; // can only be either "train" or "test" | |||
| WaitPost wp_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_; | |||
| std::vector<std::string> cifar_files_; | |||
| std::vector<std::pair<std::shared_ptr<Tensor>, std::vector<uint32_t>>> cifar_image_label_pairs_; | |||
| @@ -119,6 +119,7 @@ Status ClueOp::Init() { | |||
| } | |||
| Status ClueOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| load_jagged_connector_ = true; | |||
| load_io_block_queue_ = true; | |||
| @@ -274,6 +275,8 @@ Status ClueOp::operator()() { | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -25,7 +25,6 @@ | |||
| #include "minddata/dataset/util/auto_index.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -185,8 +185,16 @@ Status CocoOp::operator()() { | |||
| } else { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| wp_.Clear(); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -208,9 +216,9 @@ void CocoOp::Print(std::ostream &out, bool show_all) const { | |||
| } | |||
| Status CocoOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -377,7 +385,13 @@ Status CocoOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->eoe() == true) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -609,7 +623,7 @@ Status CocoOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1))); | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(this->ParseAnnotationIds()); | |||
| @@ -27,7 +27,6 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| @@ -327,10 +326,8 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| std::shared_ptr<Sampler> sampler_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| WaitPost wp_; | |||
| std::vector<std::string> image_ids_; | |||
| std::map<int32_t, std::string> image_index_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| std::vector<std::pair<std::string, std::vector<int32_t>>> label_index_; | |||
| std::map<std::string, CoordinateRow> coordinate_map_; | |||
| std::map<std::string, std::vector<uint32_t>> simple_item_map_; | |||
| @@ -479,6 +479,7 @@ Status CsvOp::CsvParser::InitCsvParser() { | |||
| } | |||
| Status CsvOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| load_jagged_connector_ = true; | |||
| load_io_block_queue_ = true; | |||
| @@ -572,6 +573,8 @@ Status CsvOp::operator()() { | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -186,7 +186,6 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) { | |||
| Status GeneratorOp::operator()() { | |||
| // Handshake with TaskManager to synchronize thread creation | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| std::unique_ptr<DataBuffer> fetched_buffer; | |||
| bool eof = false; | |||
| while (!eof) { | |||
| @@ -228,12 +227,8 @@ Status GeneratorOp::operator()() { | |||
| MS_LOG(DEBUG) << "Generator operator main execution loop complete."; | |||
| eof = true; | |||
| } else { | |||
| // Waiting for repeatOp to start new epoch | |||
| // If Reset() is called first by repeat op, this wait() will return right away. | |||
| // If Reset() is not called yet, this wait() will block until reset. | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| // Clear the status of the wait post | |||
| wp_.Clear(); | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -243,9 +238,8 @@ Status GeneratorOp::operator()() { | |||
| Status GeneratorOp::Reset() { | |||
| // Reset Op state | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(this->Init()); | |||
| // Wake up master thread | |||
| wp_.Set(); | |||
| return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); | |||
| } | |||
| @@ -144,8 +144,6 @@ class GeneratorOp : public PipelineOp { | |||
| py::object generator_; | |||
| int32_t buffer_id_; | |||
| WaitPost wp_; | |||
| Status Init(); | |||
| void Dealloc() noexcept; | |||
| @@ -164,11 +164,19 @@ Status ImageFolderOp::operator()() { | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| return Status::OK(); | |||
| } else { // not the last repeat. Sleep master thread, wait for the wake-up from reset | |||
| } else { // not the last repeat. | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -183,7 +191,13 @@ Status ImageFolderOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->eoe() == true) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -247,9 +261,9 @@ void ImageFolderOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status ImageFolderOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -365,7 +379,7 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() { | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| // The following code launch 3 threads group | |||
| // 1) A thread that walks all folders and push the folder names to a util:Queue mFoldernameQueue. | |||
| // 2) Workers that pull foldername from mFoldernameQueue, walk it and return the sorted images to mImagenameQueue | |||
| @@ -29,7 +29,6 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| @@ -263,9 +262,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| int64_t buf_cnt_; | |||
| int64_t sampler_ind_; | |||
| int64_t dirname_offset_; | |||
| WaitPost wp_; | |||
| std::vector<ImageLabelPair> image_label_pairs_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks | |||
| std::unique_ptr<Queue<std::string>> folder_name_queue_; | |||
| std::unique_ptr<Queue<FolderImagesPair>> image_name_queue_; | |||
| }; | |||
| @@ -33,8 +33,9 @@ class IOBlock { | |||
| public: | |||
| enum IOBlockFlags : uint32_t { | |||
| kDeIoBlockNone = 0, | |||
| kDeIoBlockFlagEoe = 1u, // end of IOBlocks for one epoch | |||
| kDeIoBlockFlagEof = 1u << 1 // end of IOBlocks for entire program | |||
| kDeIoBlockFlagEoe = 1u, // end of IOBlocks for one epoch | |||
| kDeIoBlockFlagEof = 1u << 1, // end of IOBlocks for entire program | |||
| kDeIoBlockFlagWait = 1u << 2 // control signal for workers to suspend operations | |||
| }; | |||
| // Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key. | |||
| @@ -73,6 +74,10 @@ class IOBlock { | |||
| // @return T/F if the IOBlock is eof | |||
| bool eof() const { return static_cast<uint32_t>(io_block_flags_) & static_cast<uint32_t>(kDeIoBlockFlagEof); } | |||
| // Does this block have the wait flag turned on? | |||
| // @return T/F is the IOBlock is wait | |||
| bool wait() const { return static_cast<uint32_t>(io_block_flags_) & static_cast<uint32_t>(kDeIoBlockFlagWait); } | |||
| // Adds a key to this block | |||
| // @param key - The key to add to this block | |||
| void AddKey(int64_t key) { index_keys_.push_back(key); } | |||
| @@ -127,8 +127,16 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) { | |||
| } else { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -140,7 +148,7 @@ Status ManifestOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&ManifestOp::WorkerEntry, this, std::placeholders::_1))); | |||
| @@ -159,7 +167,13 @@ Status ManifestOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->eoe() == true) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -235,9 +249,9 @@ void ManifestOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status ManifestOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -26,7 +26,6 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -242,8 +241,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| std::string usage_; | |||
| int64_t buf_cnt_; | |||
| WaitPost wp_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| std::map<std::string, int32_t> label_index_; | |||
| std::vector<std::pair<std::string, std::vector<std::string>>> image_labelname_; | |||
| }; | |||
| @@ -129,7 +129,9 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf | |||
| num_padded_(num_padded), | |||
| sample_json_(sample_json), | |||
| sample_bytes_(sample_bytes) { | |||
| io_blk_queues_.Init(num_workers_, op_connector_queue_size); | |||
| io_block_queues_.Init(num_workers_, op_connector_queue_size); | |||
| epoch_sync_flag_ = true; // MindRecordOp needs to turn this flag on, otherwise, calling ShuffleTask() before all | |||
| // tasks are consumed by the worker threads would cause problem. | |||
| } | |||
| // Private helper method to encapsulate some common construction/reset tasks | |||
| @@ -219,18 +221,27 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { | |||
| Status MindRecordOp::WorkerEntry(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->wait()) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| continue; | |||
| } | |||
| if (io_block->eoe()) { | |||
| RETURN_IF_NOT_OK( | |||
| out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)))); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| continue; | |||
| } | |||
| if (io_block->eof()) { | |||
| RETURN_IF_NOT_OK( | |||
| out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)))); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| continue; | |||
| } | |||
| @@ -255,7 +266,7 @@ Status MindRecordOp::WorkerEntry(int32_t worker_id) { | |||
| } | |||
| RETURN_IF_NOT_OK(GetBufferFromReader(&fetched_buffer, buffer_id, worker_id)); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(fetched_buffer))); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| } | |||
| RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker."); | |||
| } | |||
| @@ -377,27 +388,31 @@ Status MindRecordOp::operator()() { | |||
| while (true) { // each iterator is 1 epoch | |||
| for (int32_t i = 0; i < buffers_needed_; ++i) { | |||
| std::vector<int64_t> keys(1, i); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( | |||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( | |||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | |||
| } | |||
| if (IsLastIteration()) { | |||
| RETURN_IF_NOT_OK( | |||
| io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK( | |||
| io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | |||
| for (int32_t i = 0; i < num_workers_; i++) { | |||
| RETURN_IF_NOT_OK(io_blk_queues_[i]->Add( | |||
| RETURN_IF_NOT_OK(io_block_queues_[i]->Add( | |||
| std::move(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)))); | |||
| } | |||
| return Status::OK(); | |||
| } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset | |||
| } else { | |||
| RETURN_IF_NOT_OK( | |||
| io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| } | |||
| // reset our buffer count and go to loop again. | |||
| RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); | |||
| shard_reader_wait_post_.Clear(); | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) RETURN_IF_NOT_OK(Reset()); | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| } | |||
| @@ -406,10 +421,10 @@ Status MindRecordOp::operator()() { | |||
| // info from it's previous execution and then initializes itself so that it can be executed | |||
| // again. | |||
| Status MindRecordOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); // Call our super class reset first. | |||
| shard_reader_->ShuffleTask(); | |||
| shard_reader_wait_post_.Set(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -419,8 +434,8 @@ Status MindRecordOp::LaunchThreadAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_blk_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(shard_reader_wait_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| if (shard_reader_->Launch(true) == MSRStatus::FAILED) { | |||
| RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); | |||
| } | |||
| @@ -29,7 +29,6 @@ | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/mindrecord/include/shard_column.h" | |||
| @@ -247,8 +246,6 @@ class MindRecordOp : public ParallelOp { | |||
| std::vector<int32_t> columns_blob_index_; // Blob Columns to load from dataset | |||
| std::unique_ptr<ShardReader> shard_reader_; | |||
| WaitPost shard_reader_wait_post_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_blk_queues_; | |||
| std::mutex ended_worker_mutex_; | |||
| }; | |||
| @@ -135,8 +135,16 @@ Status MnistOp::operator()() { | |||
| } else { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -150,7 +158,13 @@ Status MnistOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> iOBlock; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock)); | |||
| while (iOBlock != nullptr) { | |||
| if (iOBlock->eoe() == true) { | |||
| if (iOBlock->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (iOBlock->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (iOBlock->eof() == true) { | |||
| @@ -208,9 +222,9 @@ void MnistOp::Print(std::ostream &out, bool show_all) const { | |||
| // Reset Sampler and wakeup Master thread (functor) | |||
| Status MnistOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); // wake up master thread after reset is done | |||
| return Status::OK(); | |||
| } | |||
| @@ -401,7 +415,7 @@ Status MnistOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&MnistOp::WorkerEntry, this, std::placeholders::_1))); | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(this->WalkAllFiles()); | |||
| @@ -27,7 +27,6 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -245,7 +244,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| int64_t buf_cnt_; | |||
| int64_t row_cnt_; | |||
| WaitPost wp_; | |||
| std::string folder_path_; // directory of image folder | |||
| int32_t rows_per_buffer_; | |||
| const std::string usage_; // can only be either "train" or "test" | |||
| @@ -253,7 +251,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| std::vector<MnistLabelPair> image_label_pairs_; | |||
| std::vector<std::string> image_names_; | |||
| std::vector<std::string> label_names_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -239,10 +239,15 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { | |||
| } | |||
| } | |||
| // Wait for the reset to wake us up if we're not quitting | |||
| if (!(*quitting)) { | |||
| MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entering sync wait."; | |||
| RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait()); | |||
| if (last_guy_in) { | |||
| // If we are the last worker, do reset to wake other workers up | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } else { | |||
| // If we are not the last worker, wait for the reset | |||
| RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait()); | |||
| } | |||
| prev = guys_out_.fetch_add(1); | |||
| bool last_guy_out = (prev + 1) == num_workers_; | |||
| // Last guy out will clear the wait post and set the row counts | |||
| @@ -365,7 +370,7 @@ Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) { | |||
| // info from it's previous execution and then initializes itself so that it can be executed | |||
| // again. | |||
| Status RandomDataOp::Reset() { | |||
| MS_LOG(INFO) << "RandomDataOp resetting."; | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| // Ensure all guys are in the waitpost | |||
| if (guys_in_ != num_workers_) { | |||
| @@ -136,6 +136,7 @@ Status TextFileOp::Init() { | |||
| } | |||
| Status TextFileOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| load_jagged_connector_ = true; | |||
| load_io_block_queue_ = true; | |||
| @@ -432,6 +433,8 @@ Status TextFileOp::operator()() { | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -27,7 +27,6 @@ | |||
| #include "minddata/dataset/util/auto_index.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/wait_post.h" | |||
| #include "minddata/dataset/engine/jagged_connector.h" | |||
| @@ -317,6 +317,8 @@ Status TFReaderOp::operator()() { | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| // Self-reset to start a new iteration | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| } | |||
| @@ -709,6 +711,7 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table | |||
| // Overrides base class reset method. Cleans up any state info from it's previous execution and | |||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||
| Status TFReaderOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| // start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true | |||
| load_jagged_connector_ = true; | |||
| @@ -164,8 +164,16 @@ Status VOCOp::operator()() { | |||
| } else { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| wp_.Clear(); | |||
| } | |||
| if (epoch_sync_flag_) { | |||
| // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for | |||
| // the current epoch. | |||
| RETURN_IF_NOT_OK(WaitForWorkers()); | |||
| } | |||
| // If not the last repeat, self-reset and go to loop again. | |||
| if (!IsLastIteration()) { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| UpdateRepeatAndEpochCounter(); | |||
| @@ -187,9 +195,9 @@ void VOCOp::Print(std::ostream &out, bool show_all) const { | |||
| } | |||
| Status VOCOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| row_cnt_ = 0; | |||
| wp_.Set(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -235,7 +243,13 @@ Status VOCOp::WorkerEntry(int32_t worker_id) { | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->eoe() == true) { | |||
| if (io_block->wait() == true) { | |||
| // Sync io_block is a signal that master thread wants us to pause and sync with other workers. | |||
| // The last guy who comes to this sync point should reset the counter and wake up the master thread. | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| } else if (io_block->eoe() == true) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| buffer_id = worker_id; | |||
| } else if (io_block->eof() == true) { | |||
| @@ -367,7 +381,7 @@ Status VOCOp::LaunchThreadsAndInitOp() { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&VOCOp::WorkerEntry, this, std::placeholders::_1))); | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(this->ParseImageIds()); | |||
| @@ -26,7 +26,6 @@ | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| @@ -283,9 +282,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t rows_per_buffer_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| WaitPost wp_; | |||
| std::vector<std::string> image_ids_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| std::map<std::string, int32_t> class_index_; | |||
| std::map<std::string, int32_t> label_index_; | |||
| std::map<std::string, Annotation> annotation_map_; | |||
| @@ -27,25 +27,10 @@ namespace mindspore { | |||
| namespace dataset { | |||
| RepeatPass::RepeatPass() | |||
| : is_repeated_(false), | |||
| nested_repeats_(0), | |||
| num_repeats_(1), | |||
| num_epochs_(1), | |||
| is_merge_(false), | |||
| is_cached_(false), | |||
| cache_lookup_(nullptr) {} | |||
| : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {} | |||
| // Identifies the subtree below this node as being in a repeated path of the tree. | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Create a new stack for eoe operators and push onto our stack of stacks. | |||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||
| eoe_op_stacks_.push(std::move(new_stack)); | |||
| // If we are already repeated, then this is a nested repeat. | |||
| if (is_repeated_) { | |||
| nested_repeats_++; | |||
| } | |||
| is_repeated_ = true; | |||
| // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. | |||
| // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. | |||
| if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { | |||
| @@ -73,9 +58,7 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modifie | |||
| // that RepeatOp does. However, epoch control is actually simpler because it can | |||
| // only exist as the root node so it doesn't need all the nested code. | |||
| // Create a new stack for eoe operators and push onto our stack of stacks. | |||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||
| eoe_op_stacks_.push(std::move(new_stack)); | |||
| is_repeated_ = true; | |||
| // Get the total number of epochs from the EpochCtrlOp parameter | |||
| num_epochs_ = node->num_repeats(); | |||
| // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. | |||
| @@ -102,44 +85,16 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // Hooks up any identified eoe nodes under this repeat. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking | |||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| node->AddToEoeList(leaf_op); | |||
| leaf_op = PopFromEOEOpStack(); | |||
| } | |||
| // At this point, we are done with the save area stack. It's a unique pointer to an empty stack | |||
| // at this time, so we can pop it to get rid of it. | |||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||
| if (!current_stack->empty()) { | |||
| RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); | |||
| } | |||
| eoe_op_stacks_.pop(); | |||
| // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up | |||
| // and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed | |||
| // from the save area, because the merge op above us may also take action on it later for a different | |||
| // case when there is no repeat in the merge leg. | |||
| // and set its total repeats. It is important that the op is removed from the save area, | |||
| // because the merge op above us may also take action on it later for a different case when | |||
| // there is no repeat in the merge leg. | |||
| if (is_merge_ && cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| node->AddToEoeList(std::move(cache_lookup_)); | |||
| cache_lookup_.reset(); | |||
| } | |||
| // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. | |||
| // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. | |||
| if (nested_repeats_ > 0) { | |||
| AddToEOEOpStack(node); | |||
| nested_repeats_--; | |||
| } else { | |||
| // If we are not nested, or we were the top-most repeat, now we clear the flag | |||
| if (nested_repeats_ != 0) { | |||
| RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!"); | |||
| } | |||
| is_repeated_ = false; | |||
| } | |||
| if (is_cached_) { | |||
| AddToCachedOpStack(node); | |||
| } | |||
| @@ -155,13 +110,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Hooks up any identified eoe nodes under this repeat. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| // Pop the leaf ops from the save-area stack and add them to the eoe node tracking | |||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| node->AddToEoeList(leaf_op); | |||
| leaf_op = PopFromEOEOpStack(); | |||
| } | |||
| is_repeated_ = false; | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| // We finish the walk of this EpochCtrl's descendent nodes. | |||
| @@ -172,31 +120,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) | |||
| // CacheOp removes previous leaf ops and replaces them with itself | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| is_cached_ = false; | |||
| if (is_repeated_) { | |||
| // if we are a cache within a repeat path of the tree, then there will be | |||
| // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the | |||
| // repeat or epoch ctrl operators can work with them for repeat activity during runtime. | |||
| // However, since a cache is present: | |||
| // - unflag those ops as being repeated ops | |||
| // - remove them from the eoe op stack so that repeat op above in the tree won't know about them | |||
| // - add ourself (the cache op), as an eoe op | |||
| // We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead | |||
| // the repeating behaviours shall be invoked against the cache op. | |||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| leaf_op = PopFromEOEOpStack(); | |||
| } | |||
| AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node)); | |||
| // adjust the total epochs and total repeats for ops under this cache op | |||
| std::shared_ptr<DatasetOp> cached_op = PopFromCachedOpStack(); | |||
| while (cached_op != nullptr) { | |||
| int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; | |||
| cached_op->set_total_repeats(cached_op_total_repeats); | |||
| // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 | |||
| cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); | |||
| cached_op = PopFromCachedOpStack(); | |||
| } | |||
| // if we are a cache within a repeat path of the tree, then adjust the total repeats and total epochs for cached ops. | |||
| // So that those cached nodes become 1-time use (up to eoe), never repeated. Instead | |||
| // the repeating behaviours shall be invoked against the cache op. | |||
| std::shared_ptr<DatasetOp> cached_op = PopFromCachedOpStack(); | |||
| while (cached_op != nullptr) { | |||
| int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; | |||
| cached_op->set_total_repeats(cached_op_total_repeats); | |||
| // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 | |||
| cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); | |||
| cached_op = PopFromCachedOpStack(); | |||
| } | |||
| node->set_total_repeats(num_repeats_); | |||
| @@ -207,13 +141,7 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | |||
| // for use with a controlling repeat above it. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| // If we are in a repeat path, then set our repeated flag | |||
| if (is_repeated_) { | |||
| // if we are a leaf node then save ourself in a stack for the repeat operator above us | |||
| if (node->IsLeaf()) { | |||
| AddToEOEOpStack(node); | |||
| } | |||
| } | |||
| // If we are under a cache op, then save ourselves to the cached op stack. | |||
| if (is_cached_) { | |||
| AddToCachedOpStack(node); | |||
| } | |||
| @@ -225,15 +153,11 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| // Turns off the tracking for operations under merge op | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // Setting the flag is needed since we didn't call the base class DatasetOp version | |||
| if (is_repeated_) { | |||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | |||
| // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack | |||
| if (cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| AddToEOEOpStack(std::move(cache_lookup_)); | |||
| } | |||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | |||
| // would not have been consumed yet. In that case, we need to set its total repeats for it. | |||
| if (cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| } | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| @@ -266,23 +190,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified | |||
| return Status::OK(); | |||
| } | |||
| // Adds an operator to the eoe operator stack save area | |||
| void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { | |||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||
| current_stack->push(dataset_op); | |||
| } | |||
| // Pops an operator from the eoe operator stack save area | |||
| std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() { | |||
| std::shared_ptr<DatasetOp> top_op = nullptr; | |||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||
| if (current_stack != nullptr && !current_stack->empty()) { | |||
| top_op = current_stack->top(); | |||
| current_stack->pop(); | |||
| } | |||
| return top_op; | |||
| } | |||
| // Adds an operator to the cached operator stack save area | |||
| void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); } | |||
| @@ -106,15 +106,6 @@ class RepeatPass : public NodePass { | |||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | |||
| private: | |||
| /// \brief Adds an operator to the eoe operator stack save area | |||
| /// \param op - The dataset op to work add to eoe stack | |||
| /// \return Status - The error code return | |||
| void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||
| /// \brief Pops an operator from the eoe operator stack save area | |||
| /// \return shared_ptr to the popped operator | |||
| std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | |||
| /// \brief Adds an operator to the cached operator stack save area | |||
| /// \param op - The dataset op to work add to cached stack | |||
| /// \return Status - The error code return | |||
| @@ -124,15 +115,12 @@ class RepeatPass : public NodePass { | |||
| /// \return shared_ptr to the popped operator | |||
| std::shared_ptr<DatasetOp> PopFromCachedOpStack(); | |||
| bool is_repeated_; // T/F if we are processing under a repeat | |||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||
| bool is_cached_; // T/F is we are processing under a cache op | |||
| int32_t nested_repeats_; // A counter for nested repeats | |||
| int32_t num_repeats_; // A multiplier to the total number of repeats | |||
| int32_t num_epochs_; // To save the total number of epochs | |||
| std::stack<std::unique_ptr<op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) | |||
| op_stack cached_op_stacks_; // A save area for ops under a cache op | |||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||
| bool is_cached_; // T/F is we are processing under a cache op | |||
| int32_t num_repeats_; // A multiplier to the total number of repeats | |||
| int32_t num_epochs_; // To save the total number of epochs | |||
| op_stack cached_op_stacks_; // A save area for ops under a cache op | |||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -32,7 +32,7 @@ def test_case_0(): | |||
| ds1 = ds1.map(operations=(lambda x: x + x), input_columns=col, output_columns="out") | |||
| print("************** Output Tensor *****************") | |||
| for data in ds1.create_dict_iterator(): # each data is a dictionary | |||
| for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary | |||
| # in this example, each dictionary has keys "image" and "label" | |||
| print(data["out"]) | |||
| print("************** Output Tensor *****************") | |||
| @@ -52,7 +52,7 @@ def test_case_1(): | |||
| ds1 = ds1.map(operations=(lambda x: (x, x + x)), input_columns=col, output_columns=["out0", "out1"]) | |||
| print("************** Output Tensor *****************") | |||
| for data in ds1.create_dict_iterator(): # each data is a dictionary | |||
| for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary | |||
| # in this example, each dictionary has keys "image" and "label" | |||
| print("out0") | |||
| print(data["out0"]) | |||
| @@ -75,7 +75,7 @@ def test_case_2(): | |||
| ds1 = ds1.map(operations=(lambda x, y: x + y), input_columns=col, output_columns="out") | |||
| print("************** Output Tensor *****************") | |||
| for data in ds1.create_dict_iterator(): # each data is a dictionary | |||
| for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary | |||
| # in this example, each dictionary has keys "image" and "label" | |||
| print(data["out"]) | |||
| @@ -97,7 +97,7 @@ def test_case_3(): | |||
| output_columns=["out0", "out1", "out2"]) | |||
| print("************** Output Tensor *****************") | |||
| for data in ds1.create_dict_iterator(): # each data is a dictionary | |||
| for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary | |||
| # in this example, each dictionary has keys "image" and "label" | |||
| print("out0") | |||
| print(data["out0"]) | |||
| @@ -123,7 +123,7 @@ def test_case_4(): | |||
| output_columns=["out0", "out1", "out2"], num_parallel_workers=4) | |||
| print("************** Output Tensor *****************") | |||
| for data in ds1.create_dict_iterator(): # each data is a dictionary | |||
| for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary | |||
| # in this example, each dictionary has keys "image" and "label" | |||
| print("out0") | |||
| print(data["out0"]) | |||
| @@ -141,6 +141,34 @@ def test_bucket_batch_multi_bucket_no_padding(): | |||
| assert output == expected_output | |||
| def test_bucket_batch_multi_bucket_no_padding_repeat(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) | |||
| column_names = ["col1"] | |||
| bucket_boundaries = [1, 2, 3] | |||
| bucket_batch_sizes = [3, 3, 2, 2] | |||
| element_length_function = (lambda x: x[0] % 4) | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function) | |||
| dataset = dataset.repeat(2) | |||
| expected_output = [[[2], [6]], | |||
| [[3], [7]], | |||
| [[0], [4], [8]], | |||
| [[1], [5], [9]], | |||
| [[2], [6]], | |||
| [[3], [7]], | |||
| [[0], [4], [8]], | |||
| [[1], [5], [9]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| output.append(data["col1"].tolist()) | |||
| assert output == expected_output | |||
| def test_bucket_batch_multi_bucket_with_padding(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(10)), ["col1"]) | |||
| @@ -471,6 +499,7 @@ def test_bucket_batch_invalid_column(): | |||
| if __name__ == '__main__': | |||
| test_bucket_batch_invalid_input() | |||
| test_bucket_batch_multi_bucket_no_padding() | |||
| test_bucket_batch_multi_bucket_no_padding_repeat() | |||
| test_bucket_batch_multi_bucket_with_padding() | |||
| test_bucket_batch_single_bucket_no_padding() | |||
| test_bucket_batch_single_bucket_with_padding() | |||
| @@ -406,7 +406,7 @@ def test_cifar_usage(): | |||
| try: | |||
| data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage) | |||
| num_rows = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| num_rows += 1 | |||
| except (ValueError, TypeError, RuntimeError) as e: | |||
| return str(e) | |||
| @@ -240,7 +240,7 @@ def test_mnist_usage(): | |||
| try: | |||
| data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False) | |||
| num_rows = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| num_rows += 1 | |||
| except (ValueError, TypeError, RuntimeError) as e: | |||
| return str(e) | |||
| @@ -424,7 +424,7 @@ def generator_big(maxid=20): | |||
| # test with row_data_buffer > 1 | |||
| def test_filter_by_generator_Partial(): | |||
| dataset = ds.GeneratorDataset(source=generator_mc(99), column_names=["col1", "col2"]) | |||
| dataset = ds.GeneratorDataset(source=(lambda: generator_mc(99)), column_names=["col1", "col2"]) | |||
| dataset_s = dataset.shuffle(4) | |||
| dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1) | |||
| @@ -502,7 +502,7 @@ def test_celeba_padded(): | |||
| data = data.repeat(2) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| count = count + 1 | |||
| assert count == 2 | |||