You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

project_op.cc 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/datasetops/project_op.h"
  17. #include <algorithm>
  18. #include <iomanip>
  19. #include <iostream>
  20. #include <string>
  21. #include <unordered_map>
  22. #include <utility>
  23. #include <vector>
  24. #include "dataset/engine/data_buffer.h"
  25. #include "dataset/engine/db_connector.h"
  26. #include "dataset/engine/execution_tree.h"
  27. #include "dataset/engine/opt/pass.h"
  28. #include "utils/log_adapter.h"
  29. namespace mindspore {
  30. namespace dataset {
  31. ProjectOp::Builder::Builder(const std::vector<std::string> &columns_to_project)
  32. : builder_columns_to_project_(columns_to_project) {}
  33. Status ProjectOp::Builder::SanityCheck() const {
  34. if (builder_columns_to_project_.empty()) {
  35. std::string err_msg("Columns to project is empty.");
  36. RETURN_STATUS_UNEXPECTED(err_msg);
  37. }
  38. return Status::OK();
  39. }
  40. Status ProjectOp::Builder::Build(std::shared_ptr<ProjectOp> *ptr) {
  41. RETURN_IF_NOT_OK(SanityCheck());
  42. *ptr = std::make_shared<ProjectOp>(builder_columns_to_project_);
  43. return Status::OK();
  44. }
  45. ProjectOp::ProjectOp(const std::vector<std::string> &columns_to_project)
  46. : PipelineOp(0), columns_to_project_(columns_to_project) {}
  47. void ProjectOp::Print(std::ostream &out, bool show_all) const {
  48. // Always show the id and name as first line regardless if this summary or detailed print
  49. out << "(" << std::setw(2) << operator_id_ << ") <ProjectOp>:";
  50. if (!show_all) {
  51. // Call the super class for displaying any common 1-liner info
  52. PipelineOp::Print(out, show_all);
  53. // Then show any custom derived-internal 1-liner info for this op
  54. out << "\n";
  55. } else {
  56. // Call the super class for displaying any common detailed info
  57. PipelineOp::Print(out, show_all);
  58. // Then show any custom derived-internal stuff
  59. out << "\nColumns that are projected:";
  60. for (size_t i = 0; i < columns_to_project_.size(); i++) {
  61. out << "\n" << columns_to_project_[i];
  62. }
  63. out << "\n\n";
  64. }
  65. }
  66. // Gets a buffer from the child operator and projects the buffer.
  67. Status ProjectOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  68. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id, retry_if_eoe));
  69. if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) {
  70. RETURN_IF_NOT_OK(Project(p_buffer));
  71. }
  72. return Status::OK();
  73. }
  74. Status ProjectOp::Project(std::unique_ptr<DataBuffer> *data_buffer) {
  75. std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
  76. while ((*data_buffer)->NumRows() > 0) {
  77. TensorRow current_row;
  78. RETURN_IF_NOT_OK((*data_buffer)->PopRow(&current_row));
  79. TensorRow new_row;
  80. (void)std::transform(projected_column_indices_.begin(), projected_column_indices_.end(),
  81. std::back_inserter(new_row), [&current_row](uint32_t x) { return current_row[x]; });
  82. new_tensor_table->push_back(new_row);
  83. }
  84. (*data_buffer)->set_tensor_table(std::move(new_tensor_table));
  85. return Status::OK();
  86. }
  87. // Class functor operator () override.
  88. // Most dataset ops operate by launching a thread (see ExecutionTree).
  89. // However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the
  90. // functor since this op runs inlined inside another operator. The function is overloaded to
  91. // ensure that it is not called by mistake (it will generate an error).
  92. Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); }
  93. int32_t ProjectOp::num_consumers() const {
  94. if (parent_.empty()) {
  95. MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1.";
  96. return 1;
  97. } else if (parent_[0] == nullptr) {
  98. MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0.";
  99. return 0;
  100. } else {
  101. return parent_[0]->num_consumers();
  102. }
  103. }
  104. int32_t ProjectOp::num_producers() const {
  105. if (child_.empty() || child_[0] == nullptr) {
  106. MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0.";
  107. return 0;
  108. } else {
  109. return child_[0]->num_producers();
  110. }
  111. }
  112. Status ProjectOp::EoeReceived(int32_t worker_id) {
  113. state_ = OpState::kDeOpIdle;
  114. return Status::OK();
  115. }
  116. Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
  117. // Visitor accept method for NodePass
  118. Status ProjectOp::Accept(NodePass *p, bool *modified) {
  119. // Downcast shared pointer then call visitor
  120. return p->RunOnNode(shared_from_base<ProjectOp>(), modified);
  121. }
  122. // Compute the column map and save it into our own column name map
  123. // We cannot use the super class ComputeColMap here because we're making a modification of the
  124. // map from the child map.
  125. Status ProjectOp::ComputeColMap() {
  126. if (column_name_id_map_.empty()) {
  127. std::unordered_map<std::string, int32_t> child_column_name_mapping = child_[0]->column_name_id_map();
  128. for (size_t i = 0; i < columns_to_project_.size(); i++) {
  129. std::string &current_column = columns_to_project_[i];
  130. if (child_column_name_mapping.find(current_column) == child_column_name_mapping.end()) {
  131. std::string err_msg = "ProjectOp: column " + current_column + " does not exist in child operator.";
  132. RETURN_STATUS_UNEXPECTED(err_msg);
  133. }
  134. // Setup the new column name mapping for ourself (base class field)
  135. column_name_id_map_[current_column] = i;
  136. projected_column_indices_.push_back(child_column_name_mapping[current_column]);
  137. }
  138. } else {
  139. MS_LOG(WARNING) << "Column name map is already set!";
  140. }
  141. return Status::OK();
  142. }
  143. } // namespace dataset
  144. } // namespace mindspore