From: @hfarahat (pull/15813/MERGE)
@@ -37,22 +37,26 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
  if (builder_sampler_ == nullptr) {
    builder_sampler_ = std::make_shared<DistributedSamplerRT>(0, 1, 0, false);
  }
  *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_,
                                    children_start_end_index_);
  *ptr = std::make_shared<ConcatOp>(builder_sampler_, children_flag_and_nums_, children_start_end_index_);
  return Status::OK();
}

// Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
ConcatOp::ConcatOp(const std::shared_ptr<SamplerRT> &sampler,
                   const std::vector<std::pair<int, int>> &children_flag_and_nums,
                   const std::vector<std::pair<int, int>> &children_start_end_index)
    : PipelineOp(op_connector_size),
      children_num_(0),
      sampler_(sampler),
      children_flag_and_nums_(children_flag_and_nums),
      children_start_end_index_(children_start_end_index) {}
    : ConcatOp() {
  children_flag_and_nums_ = children_flag_and_nums;
  children_start_end_index_ = children_start_end_index;
  std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler);
  if (distribute_sampler != nullptr) {
    num_shard_ = distribute_sampler->GetDeviceNum();
    shard_index_ = distribute_sampler->GetDeviceID();
  }
}

ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}
ConcatOp::ConcatOp()
    : PipelineOp(0), cur_child_(0), verified_(false), num_shard_(1), shard_index_(0), sample_number_(0) {}

// A function that prints info about the Operator
void ConcatOp::Print(std::ostream &out, bool show_all) const {
@@ -65,98 +69,16 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const {
    // Call the super class for displaying any common detailed info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nDatasets: " << children_num_ << "\n\n";
    out << "\nDatasets: " << child_.size() << "\n\n";
  }
}

// This definition is added to pass the cyclomatic complexity rule of <= 20 units
// The NOLINT directive is to disable cpplint check.
// Clang format and cpplint give conflicting recommendations on this line below.
#define f(fv, sv, shard_index)                                                                     \
  (((fv) == -1 && (sv) == -1) || ((fv) < (sv) && (shard_index) >= (fv) && (shard_index) < (sv)) || \
   ((fv) > (sv) && ((shard_index) >= (fv) || (shard_index) < (sv))))  // NOLINT
// Main entry point for Concat
Status ConcatOp::operator()() {
  TaskManager::FindMe()->Post();
  children_num_ = static_cast<int32_t>(child_.size());
  for (int32_t i = 0; i < children_num_; i++) {
    children_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
  }
  TensorRow new_row;
  int eof_count = 0;
  int sample_number = 0;
  bool is_not_mappable = true;
  bool is_not_mappable_or_second_ne_zero = true;
  int num_shard = 1;
  int shard_index = 0;
  std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_);
  if (distribute_sampler != nullptr) {
    num_shard = distribute_sampler->GetDeviceNum();
    shard_index = distribute_sampler->GetDeviceID();
  }
  while (eof_count == 0) {
    for (int i = 0; i < children_num_; i++) {
      // 1. Read the first row
      RETURN_IF_NOT_OK(children_iterators_[i]->FetchNextTensorRow(&new_row));
      if (new_row.eof()) {
        eof_count++;
        continue;
      }
      // 2. Verify the column name, column data type and rank of the column data
      if (!new_row.eoe()) {
        RETURN_IF_NOT_OK(Verify(i, new_row));
      }
      // 3. Put the data into output_connector
      if (!children_flag_and_nums_.empty()) {
        is_not_mappable = children_flag_and_nums_[i].first;
        is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[i].second);
      }
      while (!new_row.eoe() && !new_row.eof()) {
        // if the dataset is not mappable, or is a generator dataset whose source is a yield (so the number of
        // samples cannot be obtained in the Python layer), we use filtering to pick the data
        if (sample_number % num_shard == shard_index && is_not_mappable_or_second_ne_zero) {
          RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
        } else if (!is_not_mappable_or_second_ne_zero) {
          // if the dataset is mappable, or is a generator dataset whose source is not a yield,
          // get the start and end indices of the valid values
          int fv = children_start_end_index_[i].first, sv = children_start_end_index_[i].second;
          // determine whether the data allocated to the current shard id is false data
          if (f(fv, sv, shard_index)) {
            RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
          }
        }
        // if the dataset is not mappable, or is a generator dataset whose source is a yield, sample_number += 1
        if (is_not_mappable_or_second_ne_zero) {
          sample_number++;
        }
        RETURN_IF_NOT_OK(children_iterators_[i]->FetchNextTensorRow(&new_row));
      }
      // if the dataset is mappable, we don't use filtering to pick data,
      // so sample_number grows by the length of the entire dataset
      if (!is_not_mappable_or_second_ne_zero) {
        sample_number += children_flag_and_nums_[i].second;
      }
    }
    // 4. Add an eoe row after getting rows from all children
    if (eof_count == 0) {
      RETURN_IF_NOT_OK(out_connector_->SendEOE());
    }
    UpdateRepeatAndEpochCounter();
  }
  CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
                               "Something went wrong, eof count does not match the number of children.");
  // 5. Add an eof row at the end manually
  MS_LOG(DEBUG) << "Add the eof row manually in the end.";
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  return Status::OK();
}
#define f(fv, sv, shard_index)                                                     \
  ((fv == -1 && sv == -1) || (fv < sv && shard_index >= fv && shard_index < sv) || \
   (fv > sv && (shard_index >= fv || shard_index < sv)))  // NOLINT
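
As a reading aid, not part of the patch: the f macro tests whether shard_index falls inside the valid index window [fv, sv), where (-1, -1) means unrestricted and fv > sv means the window wraps around. A minimal standalone re-expression with a hypothetical name:

// Hypothetical equivalent of f(fv, sv, shard_index), spelled out for readability.
inline bool ShardInValidRange(int fv, int sv, int shard_index) {
  if (fv == -1 && sv == -1) return true;                      // no restriction: every shard keeps the row
  if (fv < sv) return shard_index >= fv && shard_index < sv;  // ordinary window
  if (fv > sv) return shard_index >= fv || shard_index < sv;  // window wraps past the last shard
  return false;                                               // fv == sv: empty window
}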
Status ConcatOp::Verify(int32_t id, const TensorRow &new_row) {
  if (id == 0) {
@@ -174,6 +96,7 @@ Status ConcatOp::Verify(int32_t id, const TensorRow &new_row) {
      }
    }
  }
  verified_ = true;
  return Status::OK();
}

@@ -211,6 +134,101 @@ Status ConcatOp::GetNumClasses(int64_t *num_classes) {
  *num_classes = max_num_classes;
  return Status::OK();
}
Status ConcatOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ConcatOp is an inlined operator."); }
bool ConcatOp::IgnoreSample() {
  bool is_not_mappable_or_second_ne_zero = true;
  if (!children_flag_and_nums_.empty()) {
    bool is_not_mappable = children_flag_and_nums_[cur_child_].first;
    is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second);
  }
  bool ret = true;
  if (sample_number_ % num_shard_ == shard_index_ && is_not_mappable_or_second_ne_zero) {
    ret = false;
  } else if (!is_not_mappable_or_second_ne_zero) {
    // if the dataset is mappable, or is a generator dataset whose source is not a yield,
    // get the start and end indices of the valid values
    int fv = children_start_end_index_[cur_child_].first, sv = children_start_end_index_[cur_child_].second;
    // determine whether the data allocated to the current shard id is false data
    if (f(fv, sv, shard_index_)) {
      ret = false;
    }
  }
  if (is_not_mappable_or_second_ne_zero) {
    sample_number_++;
  }
  return ret;
}
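
For the non-mappable path above, sample_number_ % num_shard_ == shard_index_ deals rows out round-robin across shards. A tiny illustrative helper, not pipeline code:

// With num_shard = 4 and shard_index = 1, this keeps global samples 1, 5, 9, ...
bool KeepOnThisShard(int64_t sample_number, int32_t num_shard, int32_t shard_index) {
  return sample_number % num_shard == shard_index;
}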
Status ConcatOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  bool is_not_mappable_or_second_ne_zero = true;
  if (!children_flag_and_nums_.empty()) {
    bool is_not_mappable = children_flag_and_nums_[cur_child_].first;
    is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second);
  }
  RETURN_IF_NOT_OK(child_[cur_child_]->GetNextRow(row, worker_id, retry_if_eoe));
  if (!row->eoe() && !row->eof()) {
    if (!verified_) RETURN_IF_NOT_OK(Verify(cur_child_, *row));
    if (IgnoreSample()) {
      RETURN_IF_NOT_OK(GetNextRow(row, worker_id, retry_if_eoe));
    }
    return Status::OK();
  }
  if (row->eoe()) {
    // if last child, send out eoe and reset epoch
    if (cur_child_ == child_.size() - 1) {
      // reset
      cur_child_ = 0;
      verified_ = false;
      UpdateRepeatAndEpochCounter();
      return Status::OK();
    }
    if (!is_not_mappable_or_second_ne_zero) {
      sample_number_ += children_flag_and_nums_[cur_child_].second;
    }
    cur_child_++;
    verified_ = false;
    RETURN_IF_NOT_OK(GetNextRow(row, worker_id, retry_if_eoe));
    return Status::OK();
  }
  if (row->eof()) {
    CHECK_FAIL_RETURN_UNEXPECTED(cur_child_ == 0, "Received an unexpected EOF.");
    for (int32_t i = cur_child_ + 1; i < child_.size(); i++) {
      RETURN_IF_NOT_OK(child_[i]->GetNextRow(row, worker_id, retry_if_eoe));
      CHECK_FAIL_RETURN_UNEXPECTED(row->eof(), "Row must be an EOF.");
    }
    return Status::OK();
  }
  return Status::OK();
}
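
With the connector gone, ConcatOp is now driven by its parent pulling rows. A hedged sketch of that contract (hypothetical driver code; Consume() is a placeholder):

// Pull rows from an inlined op until EOF; EOE rows mark epoch boundaries.
TensorRow row;
RETURN_IF_NOT_OK(concat_op->GetNextRow(&row, /*worker_id=*/0, /*retry_if_eoe=*/true));
while (!row.eof()) {
  if (!row.eoe()) Consume(row);  // placeholder for whatever the parent does with a data row
  RETURN_IF_NOT_OK(concat_op->GetNextRow(&row, /*worker_id=*/0, /*retry_if_eoe=*/true));
}

Note that ignored samples are skipped by recursing into GetNextRow, so a shard that filters many consecutive rows pays for it in stack depth.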
int32_t ConcatOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Concat operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Concat operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}
int32_t ConcatOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Concat operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}
}  // namespace dataset
}  // namespace mindspore
@@ -70,9 +70,8 @@ class ConcatOp : public PipelineOp {
  // Constructor of the ConcatOp.
  // @note The builder class should be used to call it
  // @param op_connector_size - connector size
  explicit ConcatOp(int32_t op_connector_size);
  ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
           const std::vector<std::pair<int, int>> &children_flag_and_nums,
  ConcatOp();
  ConcatOp(const std::shared_ptr<SamplerRT> &sampler, const std::vector<std::pair<int, int>> &children_flag_and_nums,
           const std::vector<std::pair<int, int>> &children_start_end_index);

  // Destructor
@@ -111,18 +110,29 @@ class ConcatOp : public PipelineOp {
  /// \return Status - The status code return
  Status GetNumClasses(int64_t *num_classes) override;

  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
  int32_t num_consumers() const override;
  int32_t num_producers() const override;

  /// Check if the current sample will be taken or dropped
  /// \return bool
  bool IgnoreSample();
 private:
  Status Verify(int32_t id, const TensorRow &tensor_row);

  int32_t children_num_;                                     // The number of children of this node.
  std::unordered_map<std::string, int32_t> column_name_id_;  // Mapping between col index and col name
  std::vector<DataType> data_type_;
  std::vector<dsize_t> data_rank_;
  std::shared_ptr<SamplerRT> sampler_;
  std::vector<std::pair<int, int>> children_flag_and_nums_;
  std::vector<std::pair<int, int>> children_start_end_index_;
  std::vector<std::unique_ptr<ChildIterator>> children_iterators_;  // Iterator for fetching.
  int32_t cur_child_;
  bool verified_;
  int64_t sample_number_;
  int32_t num_shard_;
  int32_t shard_index_;
};
}  // namespace dataset
}  // namespace mindspore
@@ -56,21 +56,52 @@ Status FilterOp::Builder::Build(std::shared_ptr<FilterOp> *ptr) {
FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
                   std::shared_ptr<TensorOp> predicate_func)
    : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {}

Status FilterOp::operator()() {
    : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {
  worker_queues_.Init(num_workers, op_queue_size);
}

Status FilterOp::LaunchThreadsAndInitOp() {
  // The operator class just starts off threads by calling the tree_ function.
  if (tree_ == nullptr) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
  }
  filter_queues_.Init(num_workers_, oc_queue_size_);
  RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
  Status rc =
    tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1), Name(), id());
  RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1), Name(), id()));
  RETURN_IF_NOT_OK(
    tree_->AllTasks()->CreateAsyncTask("FilterCollector", std::bind(&FilterOp::Collector, this), nullptr, id()));
  return Status::OK();
}

Status FilterOp::operator()() {
  // Synchronize with TaskManager.
  Status rc = LaunchThreadsAndInitOp();
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(rc);
  RETURN_IF_NOT_OK(Collector());
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
  TensorRow new_row;
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  int64_t cnt = 0;
  while (child_iterator_->eof_handled() == false) {
    while (new_row.empty() == false) {
      RETURN_IF_NOT_OK(worker_queues_[cnt % num_workers_]->EmplaceBack(new_row));
      cnt++;
      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    }
    RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOE))));
    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  }
  RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOF))));
  // EOF received, send quit signal to all workers
  for (int32_t ind = 0; ind < num_workers_; ind++) {
    RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagQuit))));
  }
  return Status::OK();
}
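
The master thread now fans rows out round-robin over worker_queues_ and shuts workers down with explicit kFlagQuit rows; since cnt increments once per push, the final loop touches every queue exactly once. A minimal sketch of the matching consumer side (mirroring WorkerEntry below; Process() is a placeholder):

// Hypothetical worker skeleton: pop rows until a quit row arrives.
TensorRow row;
RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&row));
while (!row.quit()) {
  RETURN_IF_NOT_OK(Process(std::move(row)));  // predicate evaluation plus hand-off to filter_queues_
  RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&row));
}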
@@ -110,36 +141,30 @@ void FilterOp::Print(std::ostream &out, bool show_all) const {
}

Status FilterOp::WorkerEntry(int32_t worker_id) {
  std::unique_ptr<ChildIterator> child_iterator = std::make_unique<ChildIterator>(this, worker_id, 0);
  // Handshake with TaskManager that thread creation is successful.
  TaskManager::FindMe()->Post();
  bool worker_stop = false;
  while (worker_stop == false) {
  TensorRow new_row;
  RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&new_row));
  while (!new_row.quit()) {
    // Getting a TensorRow to work on.
    TensorRow in_row;
    RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&in_row));
    if (in_row.eoe()) {
      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(in_row, filterCtrl::kFilterEoe)));
      continue;
    } else if (in_row.eof()) {
      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(in_row, filterCtrl::kFilterEof)));
      worker_stop = true;
      continue;
    if (new_row.eoe()) {
      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(new_row, filterCtrl::kFilterEoe)));
    } else if (new_row.eof()) {
      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(new_row, filterCtrl::kFilterEof)));
    } else {
      RETURN_IF_NOT_OK(ValidateInColumns(in_columns_));
      bool result;
      RETURN_IF_NOT_OK(WorkerCompute(new_row, &result));
      if (result)
        RETURN_IF_NOT_OK(
          filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(new_row), filterCtrl::kFilterFull)));
      else
        RETURN_IF_NOT_OK(
          filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(new_row), filterCtrl::kFilterEmpty)));
    }
    RETURN_IF_NOT_OK(ValidateInColumns(in_columns_));
    bool result;
    RETURN_IF_NOT_OK(WorkerCompute(in_row, &result));
    if (result)
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_row), filterCtrl::kFilterFull)));
    else
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_row), filterCtrl::kFilterEmpty)));
    RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&new_row));
  }
  return Status::OK();
}
@@ -160,15 +185,16 @@ Status FilterOp::WorkerCompute(const TensorRow &in_row, bool *out_predicate) {
// if the filtered TensorRow is written directly to out_connector_,
// the thread fetching data will block in a queue.
// Collector function will reorder the TensorRow in order.
// Collector thread will reorder the TensorRow in order until EOF is received
// for example in two work queues:
// in filter_queues_:
// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof)
// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe)
// queue1: TR(data1 kFilterEmpty) TR(eoe) TR(data4) TR(eof)
// queue2: TR(data2) TR(data3 kFilterEmpty) TR(eoe)
// after reorder in out_connector_:
// queue1: DB(data2) DB(data4) DB(eof)
// queue2: DB(eoe) DB(eoe)
// queue1: TR(data2) TR(data4) TR(eof)
// queue2: TR(eoe) TR(eoe)
Status FilterOp::Collector() {
  TaskManager::FindMe()->Post();
  bool collector_stop = false;
  uint64_t task_id_cnt = 0;
  uint64_t out_id_cnt = 0;
@@ -216,6 +242,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
  return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
}

int32_t FilterOp::num_consumers() const { return 1; }
}  // namespace dataset
}  // namespace mindspore
@@ -126,6 +126,8 @@ class FilterOp : public ParallelOp {
  // @return Name of the current Op
  std::string Name() const override { return kFilterOp; }

  int32_t num_consumers() const override;

 private:
  // predicate_func python callable which returns a boolean value.
  std::shared_ptr<TensorOp> predicate_func_;
@@ -136,6 +138,10 @@ class FilterOp : public ParallelOp {
  // Internal queue for filter.
  QueueList<std::pair<TensorRow, filterCtrl>> filter_queues_;

  QueueList<TensorRow> worker_queues_;  // internal queue for syncing worker
  std::unique_ptr<ChildIterator> child_iterator_;
  // Private function for a worker thread to loop continuously. It comprises the main
  // logic of FilterOp: getting data from the previous Op, validating the user-specified column names,
  // applying the predicate to each row, and filtering out rows for which the predicate returns false.
@@ -168,6 +174,10 @@ class FilterOp : public ParallelOp {
  // @param input_columns The vector of input column names used in the current thread.
  // @return Status The status code returned
  Status ValidateInColumns(const std::vector<std::string> &input_columns);

  // Do the initialization of all queues then start all worker threads
  // @return Status The status code returned
  Status LaunchThreadsAndInitOp();
};
}  // namespace dataset
@@ -43,42 +43,28 @@ Status RenameOp::Builder::SanityCheck() const { return Status::OK(); }
// build method for RenameOp
Status RenameOp::Builder::Build(std::shared_ptr<RenameOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<RenameOp>(builder_in_columns_, builder_out_columns_, builder_op_connector_size_);
  *ptr = std::make_shared<RenameOp>(builder_in_columns_, builder_out_columns_);
  return Status::OK();
}

// constructor
RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
                   int32_t op_connector_size)
    : PipelineOp(op_connector_size), in_columns_(in_col_names), out_columns_(out_col_names) {}
RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names)
    : PipelineOp(0), in_columns_(in_col_names), out_columns_(out_col_names) {}

// destructor
RenameOp::~RenameOp() {}

// main entry point for rename
Status RenameOp::operator()() {
  TaskManager::FindMe()->Post();
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
  TensorRow new_row;
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  while (!new_row.eof()) {
    while (!new_row.eoe()) {
      MS_LOG(DEBUG) << "Rename operator pushing next row.";
      RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    }
    RETURN_IF_NOT_OK(out_connector_->SendEOE());
    MS_LOG(DEBUG) << "Rename operator EOE Received.";
    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    MS_LOG(DEBUG) << "Rename operator fetching row after EOE.";
// Gets a row from the child operator and projects the row.
Status RenameOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
  if (row->eoe()) {
    UpdateRepeatAndEpochCounter();
  }
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  MS_LOG(DEBUG) << "Rename operator EOF Received.";
  return Status::OK();
}

Status RenameOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. RenameOp is an inlined operator."); }

// Rename core functionality to compute the new column name id map.
// We need to overwrite the super class ComputeColMap here because we're making a modification of the
// map from the child map.
@@ -151,15 +137,25 @@ void RenameOp::Print(std::ostream &out,  // In: The output stream to print t
  }
}

Status RenameOp::EofReceived(int32_t) {
  MS_LOG(DEBUG) << "Rename operator EOF received, do nothing now.";
  return Status::OK();
int32_t RenameOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Rename operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Rename operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}

Status RenameOp::EoeReceived(int32_t) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
int32_t RenameOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Rename operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}
}  // namespace dataset
}  // namespace mindspore
@@ -80,17 +80,11 @@ class RenameOp : public PipelineOp {
  // @param in_col_names names of columns to rename
  // @param out_col_names names of columns after rename
  // @param op_connector_size connector size
  RenameOp(const std::vector<std::string> &in_col_names,   // In: Col names to consume
           const std::vector<std::string> &out_col_names,  // In: Col names to produce
           int32_t op_connector_size);
  RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names);

  // Destructor
  ~RenameOp();

  Status EofReceived(int32_t) override;

  Status EoeReceived(int32_t) override;

  // Print function for Rename
  // @param out output stream to print to
  // @param show_all if it should print everything
@@ -112,6 +106,13 @@ class RenameOp : public PipelineOp {
  // @return Name of the current Op
  std::string Name() const override { return kRenameOp; }

  // Gets a row from the child node and projects that row. The caller is typically our parent node.
  // @param row - output pointer to the projected row.
  // @param worker_id - The worker id
  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;

  int32_t num_consumers() const override;
  int32_t num_producers() const override;

 protected:
  // Rename core functionality
  // Computing the assignment of the new column name map.
@@ -43,13 +43,12 @@ Status SkipOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
  *ptr = std::make_shared<SkipOp>(build_max_skips_);
  return Status::OK();
}

// Constructor of the SkipOp.
SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
    : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}

// Destructor
SkipOp::~SkipOp() {}

@@ -69,34 +68,48 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
  }
}
// main entry point for skip
Status SkipOp::operator()() {
  TaskManager::FindMe()->Post();
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }

  TensorRow new_row;
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  while (!new_row.eof()) {
    // Reset count
    skip_count_ = 0;
    while (!new_row.eoe()) {
      // Drop first count rows
      if (skip_count_ < max_skips_) {
        skip_count_++;
      } else {
        RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
      }
      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));

Status SkipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  bool eoe_received = false;
  while (skip_count_ < max_skips_) {
    RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
    if (row->eoe()) {
      eoe_received = true;
      break;
    }
    // we got an eoe; now try again until we get an eof
    MS_LOG(DEBUG) << "Skip operator EOE Received.";
    RETURN_IF_NOT_OK(out_connector_->SendEOE());
    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    skip_count_++;
  }
  if (!eoe_received) {
    RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
  }
  if (row->eoe()) {
    UpdateRepeatAndEpochCounter();
    skip_count_ = 0;
  }
  MS_LOG(DEBUG) << "Skip operator EOF Received.";
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  return Status::OK();
}
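
Because skip_count_ resets whenever an EOE passes through, the skip re-applies at the start of every epoch. A self-contained illustration of that semantic over plain integers (not pipeline code):

#include <iostream>
#include <vector>

int main() {
  const int max_skips = 2;
  std::vector<int> epoch = {0, 1, 2, 3};  // rows of one epoch; EOE implied at the end
  for (int e = 0; e < 2; ++e) {           // two epochs: the counter resets on each EOE
    int skip_count = 0;
    for (int row : epoch) {
      if (skip_count < max_skips) {
        ++skip_count;  // drop the first max_skips rows
        continue;
      }
      std::cout << "epoch " << e << " emits row " << row << "\n";  // rows 2 and 3
    }
  }
  return 0;
}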
int32_t SkipOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Skip operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Skip operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}

int32_t SkipOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Skip operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}
}  // namespace dataset
}  // namespace mindspore
@@ -50,7 +50,7 @@ class SkipOp : public PipelineOp {
  // Constructor of the SkipOp.
  // @note The builder class should be used to call it
  // @param count - The number of skips to do
  explicit SkipOp(int32_t count, int32_t op_connector_size);
  explicit SkipOp(int32_t count);

  // Destructor
  ~SkipOp();

@@ -69,6 +69,9 @@ class SkipOp : public PipelineOp {
  // Op name getter
  // @return Name of the current Op
  std::string Name() const override { return kSkipOp; }

  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
  int32_t num_consumers() const override;
  int32_t num_producers() const override;

 private:
  int32_t max_skips_;  // The number of skips that the user requested
@@ -43,13 +43,12 @@ Status TakeOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_);
  *ptr = std::make_shared<TakeOp>(build_max_takes_);
  return Status::OK();
}

// Constructor of the TakeOp.
TakeOp::TakeOp(int32_t count, int32_t op_connector_size)
    : PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {}
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}

// A print method typically used for debugging
void TakeOp::Print(std::ostream &out, bool show_all) const {
@@ -66,37 +65,53 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
  }
}
// Main entry point for Take
Status TakeOp::operator()() {
  TaskManager::FindMe()->Post();
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
  TensorRow new_row;
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  while (!new_row.eof()) {
    while (!new_row.eoe()) {
      if (take_count_ < max_takes_) {
        RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
        take_count_++;
        RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
      }
      if (take_count_ == max_takes_) {
        RETURN_IF_NOT_OK(child_iterator_->Drain());
        break;
      }

Status TakeOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  bool eoe_received = false;
  if (take_count_ < max_takes_) {
    RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
    if (row->eoe()) {
      eoe_received = true;
    } else {
      take_count_++;
      return Status::OK();
    }
  }
  if (take_count_ == max_takes_) {
    // drain
    while (!row->eoe()) {
      RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
    }
    eoe_received = true;
  }
  if (eoe_received) {
    UpdateRepeatAndEpochCounter();
    take_count_ = 0;
    RETURN_IF_NOT_OK(out_connector_->SendEOE());
    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  }
  take_count_ = 0;
  MS_LOG(DEBUG) << "Meet the end and push-back eof row.";
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  return Status::OK();
}
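
TakeOp emits at most max_takes_ rows per epoch, then drains the child's remaining rows before forwarding the EOE, at which point take_count_ resets. A self-contained illustration (not pipeline code):

#include <iostream>
#include <vector>

int main() {
  const int max_takes = 2;
  std::vector<int> epoch = {0, 1, 2, 3};
  int take_count = 0;
  for (int row : epoch) {
    if (take_count < max_takes) {
      std::cout << "emits row " << row << "\n";  // rows 0 and 1
      ++take_count;
    }
    // rows 2 and 3 are drained: consumed from the child but never emitted
  }
  take_count = 0;  // reset on EOE, so the next epoch again yields rows 0 and 1
  return 0;
}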
int32_t TakeOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Take operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Take operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}

int32_t TakeOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Take operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}
}  // namespace dataset
}  // namespace mindspore
@@ -53,7 +53,7 @@ class TakeOp : public PipelineOp {
  // Constructor of the TakeOp.
  // @note The builder class should be used to call it
  // @param count - The number of takes to do
  explicit TakeOp(int32_t count, int32_t op_connector_size);
  explicit TakeOp(int32_t count);

  // Destructor
  ~TakeOp() = default;

@@ -82,6 +82,10 @@ class TakeOp : public PipelineOp {
  // @return Name of the current Op
  std::string Name() const override { return kTakeOp; }

  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
  int32_t num_consumers() const override;
  int32_t num_producers() const override;

 private:
  int32_t max_takes_;   // The number of takes that the user requested
  int32_t take_count_;  // A counter for the current number of executed takes
@@ -45,106 +45,40 @@ Status ZipOp::Builder::Build(std::shared_ptr<ZipOp> *ptr) {
}

// Construct ZipOp here, local variables initialized in operator due to tree construction restrictions
ZipOp::ZipOp(int32_t op_connector_size)
    : PipelineOp(op_connector_size), children_num_(0), draining_(false), eof_(false) {}
ZipOp::ZipOp(int32_t op_connector_size) : PipelineOp(0) {}

// destructor
ZipOp::~ZipOp() {}

// Entry point for Zip, called by launch()
Status ZipOp::operator()() {
  // The children_num_ parameter needs to be put here
  children_num_ = child_.size();
  // Synchronize with TaskManager once the thread is created.
  TaskManager::FindMe()->Post();
  // initialize the iterators
  for (int32_t i = 0; i < children_num_; ++i) {
    // magic number 0 since Zip is not a parallel Op
    child_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
  }
  // Loop until eof is true
  while (!eof_) {
    // 1 Prepare new epoch
    RETURN_IF_NOT_OK(prepare());
    // 2 fetch first row
    TensorRow row;
    RETURN_IF_NOT_OK(getNextTensorRow(&row));
    // If an eof got picked up, then we're done
    if (eof_) {
      break;
    }
    while (!draining_) {
      // 3 send new row to the out connector
      MS_LOG(DEBUG) << "Zip operator finished one row, pushing, cols " << row.size() << ", map "
                    << column_name_id_map_.size() << ".";
      RETURN_IF_NOT_OK(out_connector_->Add(std::move(row)));
      // 4 fetch one more row
      RETURN_IF_NOT_OK(getNextTensorRow(&row));
    }
    // 5 handle drain state.
    if (draining_) {
      MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
      RETURN_IF_NOT_OK(drainPipeline());
      // Now that we have drained child inputs, send the eoe up.
      RETURN_IF_NOT_OK(out_connector_->SendEOE());
    }
  }
  // 6 handle eof
  MS_LOG(DEBUG) << "Zip operator got EOF, propagating.";
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  return Status::OK();
}
// Handles preprocessing of the main loop, used when starting new epoch
Status ZipOp::prepare() {
  MS_LOG(DEBUG) << "Zip operator prepares for new epoch.";
  draining_ = false;
  return Status::OK();
}

// fetches next zipped (merged) row
Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) {
Status ZipOp::getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child, int32_t worker_id,
                               bool retry_if_eoe) {
  *new_zip_row = {};
  // iterate over all iterators and generate a row
  for (int32_t i = 0; i < children_num_; ++i) {
    TensorRow new_row = {};
    RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row));
    // check each fetched row: if the row from any iterator is empty, return an empty row
    if (new_row.empty()) {
      // If we did not get a row from any of the children, then it's the end of an epoch and we can move
      // to drain state.
      MS_LOG(DEBUG) << "Zip operator child iterator produced empty row.";
      draining_ = true;
      new_zip_row->clear();
      // If we picked up an eof here, then we are completely done.
      if ((child_iterators_[i])->eof_handled()) {
        MS_LOG(DEBUG) << "Zip operator iterator got EOF.";
        eof_ = true;
      }
  for (int32_t i = 0; i < child_.size(); ++i) {
    TensorRow new_row;
    RETURN_IF_NOT_OK(child_[i]->GetNextRow(&new_row, worker_id, retry_if_eoe));
    if (new_row.eoe() || new_row.eof()) {
      *new_zip_row = new_row;
      *skip_child = i;
      return Status::OK();
    } else {
      MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << ".";
      // if the row isn't empty then we can append the fetched row to new_zip_row
      new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end());
    }
  }
  MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << ".";
  return Status::OK();
}

// drain end of epoch messages from iterator for this epoch
Status ZipOp::drainPipeline() {
  // we don't need to drain if we reached eof
  if (eof_) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                  "ZipOp draining should not be done if already at eof!");
  }
  for (int32_t con = 0; con < children_num_; ++con) {
Status ZipOp::drainPipeline(int32_t skip_child, int32_t worker_id, bool retry_if_eoe) {
  for (int32_t con = 0; con < child_.size(); ++con) {
    if (con == skip_child) continue;
    MS_LOG(DEBUG) << "Zip operator draining child at " << con << ".";
    RETURN_IF_NOT_OK(child_iterators_[con]->Drain());
    TensorRow row;
    while (!row.eoe()) {
      RETURN_IF_NOT_OK(child_[con]->GetNextRow(&row, worker_id, retry_if_eoe));
    }
  }
  // at this point no connector contains end-of-epoch messages, so the next iteration starts clean
  return Status::OK();
@@ -161,9 +95,9 @@ void ZipOp::Print(std::ostream &out,  // In: The output stream to print to
  } else {
    // Call the super class for displaying any common detailed info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nDatasets: " << children_num_ << "\n\n";
  }
    // Then show any custom derived-internal stuff
    out << "\nDatasets: " << child_.size() << "\n\n";
}

// overwrite function and handle eof
@@ -202,5 +136,39 @@ Status ZipOp::ComputeColMap() {
  }
  return Status::OK();
}
Status ZipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ZipOp is an inlined operator."); }
Status ZipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  int32_t skip_child = -1;
  RETURN_IF_NOT_OK(getNextZippedRow(row, &skip_child, worker_id, retry_if_eoe));
  if (row->eoe()) {
    UpdateRepeatAndEpochCounter();
    MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
    RETURN_IF_NOT_OK(drainPipeline(skip_child, worker_id, retry_if_eoe));
  }
  return Status::OK();
}
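
When any child reports EOE, getNextZippedRow returns early and drainPipeline pulls the remaining children through to their own EOE so every child starts the next epoch aligned. A self-contained illustration with plain vectors (not pipeline code):

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> a = {1, 2, 3};              // the shorter child ends the epoch
  std::vector<int> b = {10, 20, 30, 40, 50};
  size_t n = std::min(a.size(), b.size());
  for (size_t i = 0; i < n; ++i) {
    std::cout << a[i] << "," << b[i] << "\n";  // three zipped rows
  }
  // b[3] and b[4] are consumed but never emitted, mirroring drainPipeline()
  return 0;
}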
int32_t ZipOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Zip operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Zip operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}

int32_t ZipOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Zip operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}
}  // namespace dataset
}  // namespace mindspore
@@ -104,15 +104,16 @@ class ZipOp : public PipelineOp {
  // @return Name of the current Op
  std::string Name() const override { return kZipOp; }

 private:
  // Handles preprocessing of the main loop, used when starting new epoch
  Status prepare();
  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
  int32_t num_consumers() const override;
  int32_t num_producers() const override;

 private:
  // Special handle case where an empty row has been received from child iterator
  // @note - we need to drain eoe signals from all children connectors.
  // @details - when this function is called, then we encountered eoe at child iterator
  // we have to drain rows from other child iterators until we hit eoe from all other child iterators
  Status drainPipeline();
  Status drainPipeline(int32_t skip_child, int32_t worker_id, bool retry_if_eoe);

  // Merges 1 row from each childIterator together
  // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty
@@ -125,16 +126,11 @@ class ZipOp : public PipelineOp {
  //    1   a   T
  //     \  |  /
  //     1, a, T
  Status getNextTensorRow(TensorRow *const new_zip_row);
  Status getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child, int32_t worker_id, bool retry_if_eoe);

  // Computing the assignment of the column name map.
  // @return - Status
  Status ComputeColMap() override;

  int32_t children_num_;
  bool draining_;
  bool eof_;
  std::vector<std::unique_ptr<ChildIterator>> child_iterators_;
};
}  // namespace dataset
}  // namespace mindspore
@@ -123,12 +123,11 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  std::shared_ptr<ConcatOp> op;
  if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
    op = std::make_shared<ConcatOp>(connector_que_size_);
    op = std::make_shared<ConcatOp>();
  } else {
    std::shared_ptr<SamplerRT> sampler_rt = nullptr;
    RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
    op =
      std::make_shared<ConcatOp>(connector_que_size_, sampler_rt, children_flag_and_nums_, children_start_end_index_);
    op = std::make_shared<ConcatOp>(sampler_rt, children_flag_and_nums_, children_start_end_index_);
  }
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
@@ -58,7 +58,7 @@ Status RenameNode::ValidateParams() {
}

Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  auto op = std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_);
  auto op = std::make_shared<RenameOp>(input_columns_, output_columns_);
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
@@ -39,7 +39,7 @@ void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" +
// Function to build the SkipOp
Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  auto op = std::make_shared<SkipOp>(skip_count_, connector_que_size_);
  auto op = std::make_shared<SkipOp>(skip_count_);
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
@@ -40,7 +40,7 @@ void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + s
// Function to build the TakeOp
Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  auto op = std::make_shared<TakeOp>(take_count_, connector_que_size_);
  auto op = std::make_shared<TakeOp>(take_count_);
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
@@ -44,7 +44,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
  ASSERT_TRUE(rc.IsOk());

  // SkipOp
  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2);
  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
  rc = my_tree->AssociateNode(skip_op);
  ASSERT_TRUE(rc.IsOk());
@@ -85,7 +85,7 @@ def test_profiling_complex_pipeline():
    data = json.load(f)
    op_info = data["op_info"]
    assert len(op_info) == 5
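    # presumably the op made inline by this change no longer reports an output queue, hence range(4) below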
    for i in range(5):
    for i in range(4):
        assert "size" in op_info[i]["metrics"]["output_queue"]
        assert "length" in op_info[i]["metrics"]["output_queue"]
        assert "throughput" in op_info[i]["metrics"]["output_queue"]