You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

concat_op.cc 5.6 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iomanip>
  17. #include <utility>
  18. #include "common/utils.h"
  19. #include "dataset/core/config_manager.h"
  20. #include "dataset/engine/data_buffer.h"
  21. #include "dataset/engine/datasetops/concat_op.h"
  22. #include "dataset/engine/db_connector.h"
  23. #include "dataset/engine/execution_tree.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. // Builder constructor. Creates the builder object.
  27. ConcatOp::Builder::Builder() {
  28. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  29. builder_op_connector_size_ = cfg->op_connector_size();
  30. }
  31. // The builder "build" method creates the final object.
  32. Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
  33. *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_);
  34. return Status::OK();
  35. }
  36. // Constructor of the ConcatOp.
  37. ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}
  38. // A function that prints info about the Operator
  39. void ConcatOp::Print(std::ostream &out, bool show_all) const {
  40. // Always show the id and name as first line regardless if this is summary or detailed print
  41. out << "(" << std::setw(2) << operator_id_ << ") <ConcatOp>:";
  42. if (!show_all) {
  43. // Call the super class for displaying any common 1-liner info
  44. PipelineOp::Print(out, show_all);
  45. // Then show any custom derived-internal 1-liner info for this op
  46. out << "\n";
  47. } else {
  48. // Call the super class for displaying any common detailed info
  49. PipelineOp::Print(out, show_all);
  50. // Then show any custom derived-internal stuff
  51. out << "\nDatasets: " << children_num_ << "\n\n";
  52. }
  53. }
  54. // Main entry point for Concat
  55. Status ConcatOp::operator()() {
  56. // The children_num_ parameter needs to be put here
  57. children_num_ = static_cast<int32_t>(child_.size());
  58. TaskManager::FindMe()->Post();
  59. std::unique_ptr<DataBuffer> buf;
  60. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
  61. int eof_count = 0;
  62. while (eof_count != children_num_) {
  63. for (int i = 0; i < children_num_; i++) {
  64. // 1. Throw the eof buffer when meet it
  65. if (buf->eof() || buf->eoe()) {
  66. RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
  67. }
  68. // 2. Do verification as for column name, column data type and rank of column data
  69. RETURN_IF_NOT_OK(Verify(i, buf));
  70. // 3. Put the data into output_connector
  71. while (!buf->eoe() && !buf->eof()) {
  72. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
  73. RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
  74. }
  75. // 4. Throw the eoe buffer when meet it
  76. if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) {
  77. RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
  78. }
  79. // 5. Add eoe buffer after get buffer from all child
  80. if (i == (children_num_ - 1)) {
  81. auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  82. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
  83. }
  84. if (buf->eof()) {
  85. eof_count++;
  86. }
  87. }
  88. }
  89. // 6. Add eof buffer in the end manually
  90. MS_LOG(DEBUG) << "Add the eof buffer manualy in the end.";
  91. auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  92. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
  93. return Status::OK();
  94. }
  95. Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
  96. TensorRow new_row;
  97. buf->GetRow(0, &new_row);
  98. if (id == 0) {
  99. // Obtain the data type and data rank in child[0]
  100. for (auto item : new_row) {
  101. data_type_.push_back(item->type());
  102. data_rank_.push_back(item->Rank());
  103. }
  104. } else {
  105. // Compare the data type and data rank with these in child[0]
  106. int32_t index = 0;
  107. for (auto item : new_row) {
  108. if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) {
  109. RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same with previous dataset.");
  110. }
  111. }
  112. }
  113. return Status::OK();
  114. }
  115. Status ConcatOp::PrepareNodePostAction() {
  116. RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
  117. tree_->AddToEOEOpStack(shared_from_this());
  118. return Status::OK();
  119. }
  120. // We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
  121. Status ConcatOp::ComputeColMap() {
  122. if (column_name_id_map_.empty()) {
  123. // Obtain columns_name_id_map from child_[0]
  124. column_name_id_map_ = child_[0]->column_name_id_map();
  125. if (column_name_id_map_.empty()) {
  126. RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!");
  127. }
  128. // Verify all children have the same column name map
  129. for (int32_t i = 0; i < child_.size(); ++i) {
  130. if (child_[i]->column_name_id_map() != column_name_id_map_) {
  131. RETURN_STATUS_UNEXPECTED("The column name or column order is not the same with previous dataset.");
  132. }
  133. }
  134. } else {
  135. MS_LOG(WARNING) << "Column name map is already set!";
  136. }
  137. return Status::OK();
  138. }
  139. } // namespace dataset
  140. } // namespace mindspore