You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

skip_op.cc 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <utility>
  18. #include "dataset/engine/data_buffer.h"
  19. #include "dataset/engine/datasetops/skip_op.h"
  20. #include "dataset/engine/db_connector.h"
  21. #include "dataset/engine/execution_tree.h"
  22. #include "utils/log_adapter.h"
  23. namespace mindspore {
  24. namespace dataset {
  25. // Builder constructor. Creates the builder object.
  26. SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {}
  27. Status SkipOp::Builder::SanityCheck() const {
  28. if (build_max_skips_ < 0) {
  29. std::string err_msg("Skip count must be positive integer or 0.");
  30. RETURN_STATUS_UNEXPECTED(err_msg);
  31. }
  32. return Status::OK();
  33. }
  34. // The builder "build" method creates the final object.
  35. Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
  36. RETURN_IF_NOT_OK(SanityCheck());
  37. *ptr = std::make_shared<SkipOp>(build_max_skips_);
  38. return Status::OK();
  39. }
  40. // Constructor of the SkipOp.
  41. SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}
  42. // Destructor
  43. SkipOp::~SkipOp() {}
  44. // A print method typically used for debugging
  45. void SkipOp::Print(std::ostream &out, bool show_all) const {
  46. // Call base class printer first
  47. PipelineOp::Print(out, show_all);
  48. // Then display our own stuff
  49. out << "SkipOp:"
  50. << "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_;
  51. }
  52. // Since the buffer may contain multi rows, this function will drop the rows
  53. // that need to skip in it, and then return the buffer.
  54. Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  55. if (child_.empty()) {
  56. RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
  57. }
  58. std::unique_ptr<DataBuffer> buf;
  59. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  60. // Drop first max_skips_ rows
  61. while (skip_count_ < max_skips_) {
  62. if (buf->eoe() || buf->eof()) {
  63. break;
  64. }
  65. // Consider the rows of buffer more than 1
  66. TensorRow drop_row;
  67. int row_num = buf->NumRows();
  68. int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
  69. skip_count_ += drop_num;
  70. for (int i = 0; i < drop_num; i++) {
  71. RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
  72. }
  73. if (buf->NumRows() == 0) {
  74. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  75. }
  76. }
  77. // Handling eoe
  78. if (buf->eoe()) {
  79. RETURN_IF_NOT_OK(EoeReceived(worker_id));
  80. }
  81. // Handling eof
  82. if (buf->eof()) {
  83. RETURN_IF_NOT_OK(EofReceived(worker_id));
  84. }
  85. *p_buffer = std::move(buf);
  86. return Status::OK();
  87. }
  88. // Base-class override for handling cases when an eoe is received.
  89. Status SkipOp::EoeReceived(int32_t worker_id) {
  90. skip_count_ = 0;
  91. state_ = OpState::kDeOpIdle;
  92. return Status::OK();
  93. }
  94. // Class functor operator () override.
  95. // Most dataset ops operate by launching a thread (see ExecutionTree).
  96. // However, the SkipOp is defined as a inlined operator, so it is invalid to
  97. // launch the functor since this op runs inlined inside another operator. The
  98. // function is overloaded to ensure that it is not called by mistake (it will
  99. // generate an error).
  100. Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }
  101. // Base-class override for handling cases when an eof is received.
  102. Status SkipOp::EofReceived(int32_t worker_id) {
  103. MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
  104. return Status::OK();
  105. }
  106. } // namespace dataset
  107. } // namespace mindspore