/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include "dataset/core/config_manager.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" #include "dataset/engine/opt/pass.h" #include "utils/log_adapter.h" namespace mindspore { namespace dataset { // Builder constructor. Creates the builder object. SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) { std::shared_ptr cfg = GlobalContext::config_manager(); builder_op_connector_size_ = cfg->op_connector_size(); } Status SkipOp::Builder::SanityCheck() const { if (build_max_skips_ < 0) { std::string err_msg("Skip count must be positive integer or 0."); RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } // The builder "build" method creates the final object. Status SkipOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); *ptr = std::make_shared(build_max_skips_, builder_op_connector_size_); return Status::OK(); } // Constructor of the SkipOp. SkipOp::SkipOp(int32_t count, int32_t op_connector_size) : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {} // Destructor SkipOp::~SkipOp() {} // A print method typically used for debugging void SkipOp::Print(std::ostream &out, bool show_all) const { // Always show the id and name as first line regardless if this summary or detailed print out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); // Then show any custom derived-internal 1-liner info for this op out << " [skips: " << max_skips_ << "]\n"; } else { // Call the super class for displaying any common detailed info PipelineOp::Print(out, show_all); // Then show any custom derived-internal stuff out << "\nSkip count: " << skip_count_ << "\nMax skips: " << max_skips_ << "\n\n"; } } // Base-class override for handling cases when an eoe is received. Status SkipOp::EoeReceived(int32_t worker_id) { skip_count_ = 0; state_ = OpState::kDeOpIdle; return Status::OK(); } // main entry point for skip Status SkipOp::operator()() { TaskManager::FindMe()->Post(); std::unique_ptr curr_buffer; RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); while (curr_buffer->eof() == false) { // Reset count skip_count_ = 0; while (curr_buffer->eoe() == false) { // Drop first count rows while (skip_count_ < max_skips_) { if (curr_buffer->eoe() || curr_buffer->eof()) { break; } // Consider the rows of buffer more than one TensorRow drop_row; int row_num = curr_buffer->NumRows(); int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; skip_count_ += drop_num; for (int i = 0; i < drop_num; i++) { RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row)); } if (curr_buffer->NumRows() == 0) { RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); } } RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); } // we got eoe, now try again until we got eof MS_LOG(DEBUG) << "Skip operator EOE Received."; RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); } MS_LOG(DEBUG) << "Skip operator EOF Received."; RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); return Status::OK(); } // Base-class override for handling cases when an eof is received. Status SkipOp::EofReceived(int32_t worker_id) { MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; return Status::OK(); } // Visitor accept method for NodePass Status SkipOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset } // namespace mindspore