/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "common/utils.h" #include "dataset/core/config_manager.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/datasetops/concat_op.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" namespace mindspore { namespace dataset { // Builder constructor. Creates the builder object. ConcatOp::Builder::Builder() { std::shared_ptr cfg = GlobalContext::config_manager(); builder_op_connector_size_ = cfg->op_connector_size(); } // The builder "build" method creates the final object. Status ConcatOp::Builder::Build(std::shared_ptr *ptr) { *ptr = std::make_shared(builder_op_connector_size_); return Status::OK(); } // Constructor of the ConcatOp. ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {} // A function that prints info about the Operator void ConcatOp::Print(std::ostream &out, bool show_all) const { // Always show the id and name as first line regardless if this is summary or detailed print out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); // Then show any custom derived-internal 1-liner info for this op out << "\n"; } else { // Call the super class for displaying any common detailed info PipelineOp::Print(out, show_all); // Then show any custom derived-internal stuff out << "\nDatasets: " << children_num_ << "\n\n"; } } // Main entry point for Concat Status ConcatOp::operator()() { // The children_num_ parameter needs to be put here children_num_ = static_cast(child_.size()); TaskManager::FindMe()->Post(); std::unique_ptr buf; RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); int eof_count = 0; while (eof_count != children_num_) { for (int i = 0; i < children_num_; i++) { // 1. Throw the eof buffer when meet it if (buf->eof() || buf->eoe()) { RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); } // 2. Do verification as for column name, column data type and rank of column data RETURN_IF_NOT_OK(Verify(i, buf)); // 3. Put the data into output_connector while (!buf->eoe() && !buf->eof()) { RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); } // 4. Throw the eoe buffer when meet it if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) { RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); } // 5. Add eoe buffer after get buffer from all child if (i == (children_num_ - 1)) { auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); } if (buf->eof()) { eof_count++; } } } // 6. Add eof buffer in the end manually MS_LOG(DEBUG) << "Add the eof buffer manualy in the end."; auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); return Status::OK(); } Status ConcatOp::Verify(int32_t id, const std::unique_ptr &buf) { TensorRow new_row; buf->GetRow(0, &new_row); if (id == 0) { // Obtain the data type and data rank in child[0] for (auto item : new_row) { data_type_.push_back(item->type()); data_rank_.push_back(item->Rank()); } } else { // Compare the data type and data rank with these in child[0] int32_t index = 0; for (auto item : new_row) { if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) { RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same with previous dataset."); } } } return Status::OK(); } Status ConcatOp::PrepareNodePostAction() { RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); tree_->AddToEOEOpStack(shared_from_this()); return Status::OK(); } // We need to overwrite the super class ComputeColMap here because the number of children is more than 1. Status ConcatOp::ComputeColMap() { if (column_name_id_map_.empty()) { // Obtain columns_name_id_map from child_[0] column_name_id_map_ = child_[0]->column_name_id_map(); if (column_name_id_map_.empty()) { RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!"); } // Verify all children have the same column name map for (int32_t i = 0; i < child_.size(); ++i) { if (child_[i]->column_name_id_map() != column_name_id_map_) { RETURN_STATUS_UNEXPECTED("The column name or column order is not the same with previous dataset."); } } } else { MS_LOG(WARNING) << "Column name map is already set!"; } return Status::OK(); } } // namespace dataset } // namespace mindspore