You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

repeat_op.h 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_
  17. #define DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include "dataset/engine/datasetops/pipeline_op.h"
  22. namespace mindspore {
  23. namespace dataset {
  24. class RepeatOp : public PipelineOp {
  25. public:
  26. static constexpr int32_t kInfiniteRepeat = -1;
  27. // The nested builder class inside of the RepeatOp is used to help manage all of the arguments
  28. // for constructing it. This repeat op is very simple though, so this builder is really just
  29. // provided for a consistent look and feel for creators of Dataset operators overall.
  30. class Builder {
  31. public:
  32. // Builder constructor. Creates the builder object.
  33. // @note No default args
  34. // @param count - The number of repeats to do
  35. // @return This is a constructor.
  36. explicit Builder(int32_t count);
  37. // Default destructor
  38. ~Builder() = default;
  39. // The builder "build" method creates the final object.
  40. // @return shared_ptr to the new StorageOp object
  41. Status Build(std::shared_ptr<RepeatOp> *);
  42. private:
  43. int32_t build_max_repeats_;
  44. Status SanityCheck() const;
  45. };
  46. // Constructor of the RepeatOp.
  47. // @note The builder class should be used to call it
  48. // @param count - The number of repeats to do
  49. explicit RepeatOp(int32_t count);
  50. // Destructor
  51. ~RepeatOp();
  52. // A print method typically used for debugging
  53. // @param out - The output stream to write output to
  54. // @param show_all - A bool to control if you want to show all info or just a summary
  55. void Print(std::ostream &out, bool show_all) const override;
  56. // << Stream output operator overload
  57. // @notes This allows you to write the debug print info using stream operators
  58. // @param out - reference to the output stream being overloaded
  59. // @param ro - reference to the RepeatOp to display
  60. // @return - the output stream must be returned
  61. friend std::ostream &operator<<(std::ostream &out, const RepeatOp &ro) {
  62. ro.Print(out, false);
  63. return out;
  64. }
  65. // Class functor operator () override.
  66. // Most dataset ops operate by launching a thread (see ExecutionTree).
  67. // However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the
  68. // functor since this op runs inlined inside another operator. The function is overloaded to
  69. // ensure that it is not called by mistake (it will generate an error).
  70. // @return Status - The error code return
  71. Status operator()() override;
  72. // Base-class override for setting specific RepeatOp configurations. This code will be called
  73. // during the execution tree prepare phase BEFORE traversing down to child operators.
  74. uint32_t PrepareFlags() const override;
  75. // Base-class override for executing specific RepeatOp configurations. This code will be called
  76. // during the execution tree post-prepare phase when it is visiting this operator.
  77. Status PrepareNodePostAction() override;
  78. // This function returns the buffer that is at the top of our output connector. The caller is
  79. // typically our parent node, when the parent is asking us to provide the next buffer of data.
  80. // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
  81. // a buffer from our child.
  82. // @note This function sets the `retryIfEoe` flag when popping from the child connector. This way,
  83. // this function will retry to pop the connector again and will get the non-EOE buffer if any.
  84. // @param p_buffer - output pointer to the buffer that it will fetch.
  85. // @param worker_id - The worker id
  86. // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
  87. // @return Status - The error code return
  88. Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
  89. // Base-class override for handling cases when an eoe is received.
  90. // @param worker_id - The worker id
  91. Status EoeReceived(int32_t worker_id) override;
  92. // Base-class override for handling cases when an eof is received.
  93. // @param worker_id - The worker id
  94. Status EofReceived(int32_t worker_id) override;
  95. // Base-class override. Return the number of workers in the first parent.
  96. // @param workerId - The worker id
  97. int32_t num_consumers() const override;
  98. // Base-class override. Return the number of producers in the first child.
  99. // @param workerId - The worker id
  100. int32_t num_producers() const override;
  101. // Base-class override for NodePass visitor acceptor.
  102. // @param p - Pointer to the NodePass to be accepted.
  103. // @param modified - Whether this node visit modified the pipeline.
  104. // @return - Status of the node visit.
  105. Status Accept(NodePass *p, bool *modified) override;
  106. // Op name getter
  107. // @return Name of the current Op
  108. std::string Name() const override { return "RepeatOp"; }
  109. private:
  110. int32_t max_repeats_; // The number of repeats that the user requested
  111. int32_t repeat_count_; // A counter for the current number of executed repeats
  112. std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.
  113. };
  114. } // namespace dataset
  115. } // namespace mindspore
  116. #endif // DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_