You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

project_op.cc 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/datasetops/project_op.h"
  17. #include <algorithm>
  18. #include <iostream>
  19. #include <string>
  20. #include <unordered_map>
  21. #include <utility>
  22. #include <vector>
  23. #include "dataset/engine/data_buffer.h"
  24. #include "dataset/engine/db_connector.h"
  25. #include "dataset/engine/execution_tree.h"
  26. #include "utils/log_adapter.h"
  27. namespace mindspore {
  28. namespace dataset {
  29. ProjectOp::Builder::Builder(const std::vector<std::string> &columns_to_project)
  30. : builder_columns_to_project_(columns_to_project) {}
  31. Status ProjectOp::Builder::SanityCheck() const {
  32. if (builder_columns_to_project_.empty()) {
  33. std::string err_msg("Columns to project is empty.");
  34. RETURN_STATUS_UNEXPECTED(err_msg);
  35. }
  36. return Status::OK();
  37. }
  38. Status ProjectOp::Builder::Build(std::shared_ptr<ProjectOp> *ptr) {
  39. RETURN_IF_NOT_OK(SanityCheck());
  40. *ptr = std::make_shared<ProjectOp>(builder_columns_to_project_);
  41. return Status::OK();
  42. }
  43. ProjectOp::ProjectOp(const std::vector<std::string> &columns_to_project)
  44. : PipelineOp(0), columns_to_project_(columns_to_project) {}
  45. void ProjectOp::Print(std::ostream &out, bool show_all) const {
  46. PipelineOp::Print(out, show_all);
  47. out << "ProjectOp: columns that are projected: ";
  48. for (size_t i = 0; i < columns_to_project_.size(); i++) {
  49. out << columns_to_project_[i] << " ";
  50. }
  51. out << '\n';
  52. }
  53. // Gets a buffer from the child operator and projects the buffer.
  54. Status ProjectOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  55. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id, retry_if_eoe));
  56. if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) {
  57. RETURN_IF_NOT_OK(Project(p_buffer));
  58. }
  59. return Status::OK();
  60. }
  61. Status ProjectOp::Project(std::unique_ptr<DataBuffer> *data_buffer) {
  62. std::unordered_map<std::string, int32_t> column_name_mapping = (*data_buffer)->column_name_map();
  63. std::unordered_map<std::string, int32_t> new_column_name_mapping;
  64. std::vector<int32_t> projected_column_indices;
  65. for (size_t i = 0; i < columns_to_project_.size(); i++) {
  66. std::string &current_column = columns_to_project_[i];
  67. if (column_name_mapping.find(current_column) == column_name_mapping.end()) {
  68. std::string err_msg = "ProjectOp: column " + current_column + " does not exist in this buffer.";
  69. RETURN_STATUS_UNEXPECTED(err_msg);
  70. }
  71. new_column_name_mapping[current_column] = i;
  72. projected_column_indices.push_back(column_name_mapping[current_column]);
  73. }
  74. std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
  75. while ((*data_buffer)->NumRows() > 0) {
  76. TensorRow current_row;
  77. RETURN_IF_NOT_OK((*data_buffer)->PopRow(&current_row));
  78. TensorRow new_row;
  79. (void)std::transform(projected_column_indices.begin(), projected_column_indices.end(), std::back_inserter(new_row),
  80. [&current_row](uint32_t x) { return current_row[x]; });
  81. new_tensor_table->push_back(new_row);
  82. }
  83. (*data_buffer)->set_tensor_table(std::move(new_tensor_table));
  84. (*data_buffer)->set_column_name_map(new_column_name_mapping);
  85. return Status::OK();
  86. }
  87. // Class functor operator () override.
  88. // Most dataset ops operate by launching a thread (see ExecutionTree).
  89. // However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the
  90. // functor since this op runs inlined inside another operator. The function is overloaded to
  91. // ensure that it is not called by mistake (it will generate an error).
  92. Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); }
  93. int32_t ProjectOp::num_consumers() const {
  94. if (parent_.empty()) {
  95. MS_LOG(INFO) << "Project operator, no parent node, assuming it's the root and returning 1.";
  96. return 1;
  97. } else if (parent_[0] == nullptr) {
  98. MS_LOG(INFO) << "Project operator, pointer to the first parent is null. Returning 0.";
  99. return 0;
  100. } else {
  101. return parent_[0]->num_consumers();
  102. }
  103. }
  104. int32_t ProjectOp::num_producers() const {
  105. if (child_.empty() || child_[0] == nullptr) {
  106. MS_LOG(INFO) << "Project operator, pointer to child node is null. Returning 0.";
  107. return 0;
  108. } else {
  109. return child_[0]->num_producers();
  110. }
  111. }
  112. Status ProjectOp::EoeReceived(int32_t worker_id) {
  113. state_ = OpState::kDeOpIdle;
  114. return Status::OK();
  115. }
  116. Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
  117. } // namespace dataset
  118. } // namespace mindspore