Merge pull request !1272 from JunhanHu/minddata_opttags/v0.3.0-alpha
| @@ -66,6 +66,7 @@ set(submodules | |||
| $<TARGET_OBJECTS:engine-datasetops-source> | |||
| $<TARGET_OBJECTS:engine-datasetops-source-sampler> | |||
| $<TARGET_OBJECTS:engine-datasetops> | |||
| $<TARGET_OBJECTS:engine-opt> | |||
| $<TARGET_OBJECTS:engine> | |||
| ) | |||
| @@ -1,4 +1,5 @@ | |||
| add_subdirectory(datasetops) | |||
| add_subdirectory(opt) | |||
| if (ENABLE_TDTQUE) | |||
| add_subdirectory(tdt) | |||
| endif () | |||
| @@ -14,7 +15,7 @@ add_library(engine OBJECT | |||
| target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) | |||
| if (ENABLE_TDTQUE) | |||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt) | |||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt) | |||
| else() | |||
| add_dependencies(engine engine-datasetops engine-datasetops-source) | |||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt) | |||
| endif () | |||
| @@ -22,6 +22,7 @@ | |||
| #include "dataset/core/pybind_support.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| using float16 = Eigen::half; | |||
| @@ -462,5 +463,11 @@ Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> d | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status BatchOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -192,6 +192,12 @@ class BatchOp : public ParallelOp { | |||
| Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape, | |||
| float pad_val); | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| // recursive helper function. This function could be very expensive if called on a multi-dimensional tensor | |||
| // it is only meant to be called by PadTensor. | |||
| @@ -25,6 +25,7 @@ | |||
| #include "dataset/engine/datasetops/device_queue_op.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -249,5 +250,11 @@ Status DatasetOp::AssignColMapFromChild() { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status DatasetOp::Accept(NodePass *p, bool *modified) { | |||
| // DatasetOp is the base class of visitor target. | |||
| // This method will only be called if its derived class does not implement one. | |||
| return p->RunOnNode(shared_from_this(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -32,6 +32,8 @@ class ExecutionTree; | |||
| class DataBuffer; | |||
| class NodePass; | |||
| // The base class DatasetOp is the main tree node. It is an abstract class, so | |||
| // the actual implementation of the operators will be derived from here. | |||
| class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| @@ -209,6 +211,16 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| // @return - the column name map as a string | |||
| std::string ColumnNameMapAsString() const; | |||
| // Children Getter | |||
| // @return Vector or Children | |||
| std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; } | |||
| // Base method for NodePass visit. | |||
| // Subclass needs to override this if it requires special node visit access. | |||
| // Check "dataset/engine/opt/pass.h" for more details. | |||
| // @return Statue of the node visit | |||
| virtual Status Accept(NodePass *p, bool *modified); | |||
| protected: | |||
| // Adds a parent operator to this operator | |||
| // @notes External callers do not have access to this function. | |||
| @@ -24,6 +24,7 @@ | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/util/status.h" | |||
| #include "dataset/util/task_manager.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #ifdef ENABLE_TDTQUE | |||
| #include "tdt/tsd_client.h" | |||
| @@ -265,5 +266,12 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { | |||
| out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n"; | |||
| } | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -134,6 +134,12 @@ class DeviceQueueOp : public PipelineOp { | |||
| Status operator()() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| // Name: checkExceptions(DataBuffer); | |||
| // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp | |||
| @@ -27,6 +27,7 @@ | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "dataset/util/task_manager.h" | |||
| @@ -259,5 +260,11 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate | |||
| } | |||
| return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status FilterOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -121,6 +121,12 @@ class FilterOp : public ParallelOp { | |||
| // @param show_all A bool to control if you want to show all info or just a summary. | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| // predicate_func python callable which returns a boolean value. | |||
| py::function predicate_func_; | |||
| @@ -27,6 +27,7 @@ | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "dataset/util/task_manager.h" | |||
| @@ -370,5 +371,11 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name | |||
| column_name_id_map_ = final_col_name_id_map; | |||
| } | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status MapOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -171,6 +171,12 @@ class MapOp : public ParallelOp { | |||
| // @return the number of threads consuming data from previous op's output Connector. | |||
| int32_t num_consumers() const override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| // Local queues where worker threads can pop from. | |||
| // Popping directly from the Connector can block if the previous designated threads haven't pop. | |||
| @@ -25,6 +25,7 @@ | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| @@ -144,5 +145,11 @@ Status ProjectOp::EoeReceived(int32_t worker_id) { | |||
| } | |||
| Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } | |||
| // Visitor accept method for NodePass | |||
| Status ProjectOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -101,6 +101,12 @@ class ProjectOp : public PipelineOp { | |||
| // @return Status - The error code returned. | |||
| Status EofReceived(int32_t worker_id) override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| std::vector<std::string> columns_to_project_; | |||
| std::vector<int32_t> projected_column_indices_; | |||
| @@ -24,6 +24,7 @@ | |||
| #include "dataset/core/global_context.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| @@ -170,5 +171,11 @@ Status RenameOp::EoeReceived(int32_t) { | |||
| state_ = OpState::kDeOpIdle; | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status RenameOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -110,6 +110,12 @@ class RenameOp : public PipelineOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| protected: | |||
| // Rename core functionality | |||
| Status RenameColumns(); | |||
| @@ -21,6 +21,7 @@ | |||
| #include "dataset/engine/datasetops/repeat_op.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -187,5 +188,11 @@ int32_t RepeatOp::num_producers() const { | |||
| return child_[0]->num_producers(); | |||
| } | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status RepeatOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -118,6 +118,12 @@ class RepeatOp : public PipelineOp { | |||
| // @param workerId - The worker id | |||
| int32_t num_producers() const override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t max_repeats_; // The number of repeats that the user requested | |||
| int32_t repeat_count_; // A counter for the current number of executed repeats | |||
| @@ -30,6 +30,7 @@ | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/status.h" | |||
| @@ -296,5 +297,11 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) { | |||
| state_ = OpState::kDeOpIdle; | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status ShuffleOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -155,6 +155,12 @@ class ShuffleOp : public PipelineOp { | |||
| // @return Status - The error code return | |||
| Status EoeReceived(int32_t worker_id) override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| // Private function to add a new row to the shuffle buffer. | |||
| // @return Status - The error code return | |||
| @@ -22,6 +22,7 @@ | |||
| #include "dataset/engine/datasetops/skip_op.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -128,5 +129,11 @@ Status SkipOp::EofReceived(int32_t worker_id) { | |||
| MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status SkipOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -74,6 +74,12 @@ class SkipOp : public PipelineOp { | |||
| // @param worker_id - The worker id | |||
| Status EofReceived(int32_t worker_id) override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t max_skips_; // The number of skips that the user requested | |||
| int32_t skip_count_; // A counter for the current number of executed skips | |||
| @@ -20,6 +20,7 @@ | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/util/task_manager.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -250,5 +251,11 @@ Status GeneratorOp::Reset() { | |||
| wp_.Set(); | |||
| return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status GeneratorOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -121,6 +121,12 @@ class GeneratorOp : public PipelineOp { | |||
| // @return Status - The error code return | |||
| Status Reset() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| py::function generator_function_; | |||
| std::vector<std::string> column_names_; | |||
| @@ -22,6 +22,7 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -451,5 +452,11 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t | |||
| (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1); | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status ImageFolderOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -225,6 +225,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes, | |||
| int64_t dev_id = 0, int64_t num_dev = 1); | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| // @return Status - The error code return | |||
| @@ -29,6 +29,7 @@ | |||
| #include "dataset/engine/datasetops/dataset_op.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| @@ -684,5 +685,11 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status MindRecordOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -195,6 +195,12 @@ class MindRecordOp : public ParallelOp { | |||
| Status SetColumnsBlob(); | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id); | |||
| @@ -37,6 +37,7 @@ | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/jagged_connector.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/util/path.h" | |||
| #include "dataset/util/queue.h" | |||
| #include "dataset/util/random.h" | |||
| @@ -1037,5 +1038,11 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file | |||
| return rows_read; | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status TFReaderOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -222,6 +222,12 @@ class TFReaderOp : public ParallelOp { | |||
| static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1, | |||
| bool estimate = false); | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| // The entry point for when workers are launched. | |||
| // @param worker_id - the id of the worker that is executing this function. | |||
| @@ -22,6 +22,7 @@ | |||
| #include "dataset/engine/datasetops/take_op.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -132,5 +133,11 @@ Status TakeOp::PrepareNodePostAction() { | |||
| tree_->AddToRepeatStack(shared_from_this()); | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status TakeOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -84,6 +84,12 @@ class TakeOp : public PipelineOp { | |||
| // before providing their own implementations. | |||
| Status PrepareNodePostAction() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t max_takes_; // The number of takes that the user requested | |||
| int32_t take_count_; // A counter for the current number of executed takes | |||
| @@ -19,6 +19,7 @@ | |||
| #include "dataset/core/constants.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/global_context.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -250,5 +251,11 @@ Status ZipOp::EoeReceived(int32_t) { | |||
| state_ = OpState::kDeOpIdle; | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status ZipOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -104,6 +104,12 @@ class ZipOp : public PipelineOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| // Handles preprocessing of the main loop, used when starting new epoch | |||
| Status prepare(TensorQTable *const table); | |||
| @@ -20,6 +20,8 @@ | |||
| #include "dataset/engine/datasetops/shuffle_op.h" | |||
| #include "dataset/util/task_manager.h" | |||
| #include "dataset/engine/opt/util/printer_pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| @@ -161,10 +163,54 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui | |||
| return Status::OK(); | |||
| } | |||
| // The driver of the prepare phase of the execution tree. | |||
| // Prepare phase consists of three sub phases | |||
| // | |||
| // 1. PrepareTreePreAction() | |||
| // Compulsory transformation/action pre optimization. | |||
| // For example, CacheOp Insertion | |||
| // | |||
| // 2. Optimize() | |||
| // Optimization transformation/action, optional | |||
| // For example, MapOp Fusion | |||
| // | |||
| // 3. PrepareTreePostAction() | |||
| // Compulsory transformation/action post optimization. | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status - The error code return | |||
| Status ExecutionTree::Prepare() { | |||
| // Pre optimization compulsory transformation | |||
| RETURN_IF_NOT_OK(this->PrepareTreePreAction()); | |||
| // Optimization transformation | |||
| RETURN_IF_NOT_OK(this->Optimize()); | |||
| // Post optimization compulsory transformation | |||
| RETURN_IF_NOT_OK(this->PrepareTreePostAction()); | |||
| // Existing transformation implementation, will be removed later | |||
| RETURN_IF_NOT_OK(this->PrepareDeprecated()); | |||
| return Status::OK(); | |||
| } | |||
| Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); } | |||
| Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); } | |||
| Status ExecutionTree::Optimize() { | |||
| // auto pp = new PrinterPass(); | |||
| // bool modified = false; | |||
| // pp->Run(this, &modified); | |||
| return Status::OK(); | |||
| } | |||
| // The driver of the prepare phase of the execution tree. The prepare phase will recursively | |||
| // walk the tree to perform modifications to the tree or specific nodes within the tree to get | |||
| // it ready for execution. | |||
| Status ExecutionTree::Prepare() { | |||
| // | |||
| // This driver is deprecated. | |||
| Status ExecutionTree::PrepareDeprecated() { | |||
| // Tree must be in pending prepare state before we can assign root to it | |||
| if (tree_state_ != kDeTStatePrepare) { | |||
| std::string err_msg = | |||
| @@ -152,11 +152,41 @@ class ExecutionTree { | |||
| // @return the prepare flags | |||
| uint32_t PrepareFlags() const { return prepare_flags_; } | |||
| // The driver of the prepare phase of the execution tree. The prepare phase will recursively | |||
| // The driver of the prepare phase of the execution tree. | |||
| // Prepare phase consists of three sub phases | |||
| // | |||
| // 1. PrepareTreePreAction() | |||
| // Compulsory transformation/action pre optimization. | |||
| // For example, CacheOp Insertion | |||
| // | |||
| // 2. Optimize() | |||
| // Optimization transformation/action, optional | |||
| // For example, MapOp Fusion | |||
| // | |||
| // 3. PrepareTreePostAction() | |||
| // Compulsory transformation/action post optimization. | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status - The error code return | |||
| Status Prepare(); | |||
| // Compulsory transformation/action pre optimization. | |||
| // @return Status - The error code return | |||
| Status PrepareTreePreAction(); | |||
| // Compulsory transformation/action post optimization. | |||
| // @return Status - The error code return | |||
| Status PrepareTreePostAction(); | |||
| // Optimization transformation/action, optional. | |||
| // @return Status - The error code return | |||
| Status Optimize(); | |||
| // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively | |||
| // walk the tree to perform modifications to the tree or specific nodes within the tree to get | |||
| // it ready for execution. | |||
| // @return Status - The error code return | |||
| Status Prepare(); | |||
| Status PrepareDeprecated(); | |||
| // Recursive function used during prepare phase to visit a node and drive any pre- and post- | |||
| // node actions during a tree walk. | |||
| @@ -0,0 +1,6 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(engine-opt OBJECT | |||
| pass.cc | |||
| util/printer_pass.cc | |||
| ) | |||
| @@ -0,0 +1,157 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/engine/datasetops/batch_op.h" | |||
| #include "dataset/engine/datasetops/dataset_op.h" | |||
| #include "dataset/engine/datasetops/device_queue_op.h" | |||
| #include "dataset/engine/datasetops/map_op.h" | |||
| #include "dataset/engine/datasetops/project_op.h" | |||
| #include "dataset/engine/datasetops/rename_op.h" | |||
| #include "dataset/engine/datasetops/filter_op.h" | |||
| #include "dataset/engine/datasetops/repeat_op.h" | |||
| #include "dataset/engine/datasetops/skip_op.h" | |||
| #include "dataset/engine/datasetops/shuffle_op.h" | |||
| #include "dataset/engine/datasetops/source/generator_op.h" | |||
| #include "dataset/engine/datasetops/source/mindrecord_op.h" | |||
| #include "dataset/engine/datasetops/source/storage_op.h" | |||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "dataset/engine/datasetops/take_op.h" | |||
| #include "dataset/engine/datasetops/zip_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Driver method for TreePass | |||
| Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); } | |||
| // Driver method for NodePass | |||
| Status NodePass::Run(ExecutionTree *tree, bool *modified) { | |||
| std::shared_ptr<DatasetOp> root = tree->root(); | |||
| if (traversalOrder_ == Order::DFS) { | |||
| // DFS | |||
| return DFSNodeVisit(root, modified); | |||
| } else if (traversalOrder_ == Order::BFS) { | |||
| // BFS | |||
| return BFSNodeVisit(root, modified); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to perform DFS visit | |||
| Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| for (const auto &c : node->Children()) { | |||
| RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); | |||
| } | |||
| return node->Accept(this, modified); | |||
| } | |||
| // Helper function to perform BFS visit | |||
| Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified) { | |||
| // Initialize bfs queue with root | |||
| std::queue<std::shared_ptr<DatasetOp>> bfsQueue; | |||
| bfsQueue.push(root); | |||
| // BFS loop | |||
| while (!bfsQueue.empty()) { | |||
| // Pop the front of the bfs queue | |||
| auto curNode = bfsQueue.front(); | |||
| bfsQueue.pop(); | |||
| // Run node pass | |||
| RETURN_IF_NOT_OK(curNode->Accept(this, modified)); | |||
| // Push children into bfs queue | |||
| for (const auto &c : curNode->Children()) { | |||
| bfsQueue.push(c); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,146 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_H_ | |||
| #include <memory> | |||
| #include <queue> | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class BatchOp; | |||
| class MapOp; | |||
| class ProjectOp; | |||
| class RenameOp; | |||
| class FilterOp; | |||
| class SkipOp; | |||
| class ShuffleOp; | |||
| class GeneratorOp; | |||
| class MindRecordOp; | |||
| class TFReaderOp; | |||
| class TakeOp; | |||
| class ZipOp; | |||
| class DeviceQueueOp; | |||
| class ImageFolderOp; | |||
| // The base class Pass is the basic unit of tree transformation. | |||
| // The actual implementation of the passes will be derived from here. | |||
| class Pass : public std::enable_shared_from_this<Pass> { | |||
| public: | |||
| // Run the transformation pass again the execution tree. | |||
| // @param tree - Pointer to the execution tree to be transformed. | |||
| // @param modified - Pointer to the modified flag, | |||
| virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); } | |||
| }; | |||
| // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. | |||
| class TreePass : public Pass { | |||
| public: | |||
| // Run the transformation pass against the execution tree. | |||
| // @param tree - Pointer to the execution tree to be transformed. | |||
| // @param modified - Pointer to the modified flag, | |||
| Status Run(ExecutionTree *tree, bool *modified) final; | |||
| // Derived classes may implement the runOnTree function to implement tree transformation. | |||
| // "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| // @return Status - The error code return | |||
| virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | |||
| }; | |||
| // NodePass is a basic Pass class which performs transformation on Node visiting. | |||
| // NodePass implements Visitor design pattern. | |||
| class NodePass : public Pass { | |||
| public: | |||
| // Tree traversal order | |||
| enum Order { DFS, BFS }; | |||
| // Constructor | |||
| // Default DFS traversal | |||
| explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } | |||
| // Run the transformation pass against the execution tree. | |||
| // @param tree - Pointer to the execution tree to be transformed. | |||
| // @param modified - Pointer to the modified flag, | |||
| Status Run(ExecutionTree *tree, bool *modified) final; | |||
| // Derived classes may implement the runOnNode function to implement node level tree transformation. | |||
| // "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| // @return Status - The error code return | |||
| virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | |||
| // Visit methods to be overridden. | |||
| // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode | |||
| // of its own type and override "Accept" from DatasetOp. | |||
| virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); | |||
| private: | |||
| // Helper function to perform DFS visit | |||
| Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified); | |||
| // Helper function to perform BFS visit | |||
| Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified); | |||
| // Tree traversal order of the NodePass | |||
| Order traversalOrder_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_H_ | |||
| @@ -0,0 +1,111 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "dataset/engine/opt/util/printer_pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting DatasetOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting BatchOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting MapOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting ProjectOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting RenameOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting FilterOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting SkipOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting ShuffleOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting GeneratorOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting MindRecordOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting TFReaderOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting TakeOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting ZipOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting DeviceQueueOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting ImageFolderOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H | |||
| #define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H | |||
| #include <memory> | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class PrinterPass : public NodePass { | |||
| public: | |||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| # Generate 1d int numpy array from 0 - 63 | |||
| def generator_1d(): | |||
| for i in range(64): | |||
| yield (np.array([i]),) | |||
| def test_case_0(): | |||
| """ | |||
| Test 1D Generator | |||
| """ | |||
| # apply dataset operations | |||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| data1 = data1.shuffle(2) | |||
| data1 = data1.map(["data"], operations=(lambda x : x)) | |||
| data1 = data1.batch(2) | |||
| i = 0 | |||
| for item in data1.create_dict_iterator(): # each data is a dictionary | |||
| pass | |||
| if __name__ == "__main__": | |||
| test_case_0() | |||