Merge pull request !1272 from JunhanHu/minddata_opttags/v0.3.0-alpha
| @@ -66,6 +66,7 @@ set(submodules | |||||
| $<TARGET_OBJECTS:engine-datasetops-source> | $<TARGET_OBJECTS:engine-datasetops-source> | ||||
| $<TARGET_OBJECTS:engine-datasetops-source-sampler> | $<TARGET_OBJECTS:engine-datasetops-source-sampler> | ||||
| $<TARGET_OBJECTS:engine-datasetops> | $<TARGET_OBJECTS:engine-datasetops> | ||||
| $<TARGET_OBJECTS:engine-opt> | |||||
| $<TARGET_OBJECTS:engine> | $<TARGET_OBJECTS:engine> | ||||
| ) | ) | ||||
| @@ -1,4 +1,5 @@ | |||||
| add_subdirectory(datasetops) | add_subdirectory(datasetops) | ||||
| add_subdirectory(opt) | |||||
| if (ENABLE_TDTQUE) | if (ENABLE_TDTQUE) | ||||
| add_subdirectory(tdt) | add_subdirectory(tdt) | ||||
| endif () | endif () | ||||
| @@ -14,7 +15,7 @@ add_library(engine OBJECT | |||||
| target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) | target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) | ||||
| if (ENABLE_TDTQUE) | if (ENABLE_TDTQUE) | ||||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt) | |||||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt) | |||||
| else() | else() | ||||
| add_dependencies(engine engine-datasetops engine-datasetops-source) | |||||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt) | |||||
| endif () | endif () | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "dataset/core/pybind_support.h" | #include "dataset/core/pybind_support.h" | ||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| using float16 = Eigen::half; | using float16 = Eigen::half; | ||||
| @@ -462,5 +463,11 @@ Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> d | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status BatchOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -192,6 +192,12 @@ class BatchOp : public ParallelOp { | |||||
| Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape, | Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape, | ||||
| float pad_val); | float pad_val); | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // recursive helper function. This function could be very expensive if called on a multi-dimensional tensor | // recursive helper function. This function could be very expensive if called on a multi-dimensional tensor | ||||
| // it is only meant to be called by PadTensor. | // it is only meant to be called by PadTensor. | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "dataset/engine/datasetops/device_queue_op.h" | #include "dataset/engine/datasetops/device_queue_op.h" | ||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -249,5 +250,11 @@ Status DatasetOp::AssignColMapFromChild() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DatasetOp::Accept(NodePass *p, bool *modified) { | |||||
| // DatasetOp is the base class of visitor target. | |||||
| // This method will only be called if its derived class does not implement one. | |||||
| return p->RunOnNode(shared_from_this(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,6 +32,8 @@ class ExecutionTree; | |||||
| class DataBuffer; | class DataBuffer; | ||||
| class NodePass; | |||||
| // The base class DatasetOp is the main tree node. It is an abstract class, so | // The base class DatasetOp is the main tree node. It is an abstract class, so | ||||
| // the actual implementation of the operators will be derived from here. | // the actual implementation of the operators will be derived from here. | ||||
| class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | ||||
| @@ -209,6 +211,16 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| // @return - the column name map as a string | // @return - the column name map as a string | ||||
| std::string ColumnNameMapAsString() const; | std::string ColumnNameMapAsString() const; | ||||
| // Children Getter | |||||
| // @return Vector or Children | |||||
| std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; } | |||||
| // Base method for NodePass visit. | |||||
| // Subclass needs to override this if it requires special node visit access. | |||||
| // Check "dataset/engine/opt/pass.h" for more details. | |||||
| // @return Statue of the node visit | |||||
| virtual Status Accept(NodePass *p, bool *modified); | |||||
| protected: | protected: | ||||
| // Adds a parent operator to this operator | // Adds a parent operator to this operator | ||||
| // @notes External callers do not have access to this function. | // @notes External callers do not have access to this function. | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "dataset/engine/dataset_iterator.h" | #include "dataset/engine/dataset_iterator.h" | ||||
| #include "dataset/util/status.h" | #include "dataset/util/status.h" | ||||
| #include "dataset/util/task_manager.h" | #include "dataset/util/task_manager.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #ifdef ENABLE_TDTQUE | #ifdef ENABLE_TDTQUE | ||||
| #include "tdt/tsd_client.h" | #include "tdt/tsd_client.h" | ||||
| @@ -265,5 +266,12 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n"; | out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n"; | ||||
| } | } | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -134,6 +134,12 @@ class DeviceQueueOp : public PipelineOp { | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // Name: checkExceptions(DataBuffer); | // Name: checkExceptions(DataBuffer); | ||||
| // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp | // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "dataset/kernels/tensor_op.h" | #include "dataset/kernels/tensor_op.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "dataset/util/task_manager.h" | #include "dataset/util/task_manager.h" | ||||
| @@ -259,5 +260,11 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate | |||||
| } | } | ||||
| return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); | return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status FilterOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -121,6 +121,12 @@ class FilterOp : public ParallelOp { | |||||
| // @param show_all A bool to control if you want to show all info or just a summary. | // @param show_all A bool to control if you want to show all info or just a summary. | ||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // predicate_func python callable which returns a boolean value. | // predicate_func python callable which returns a boolean value. | ||||
| py::function predicate_func_; | py::function predicate_func_; | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "dataset/kernels/tensor_op.h" | #include "dataset/kernels/tensor_op.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "dataset/util/task_manager.h" | #include "dataset/util/task_manager.h" | ||||
| @@ -370,5 +371,11 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name | |||||
| column_name_id_map_ = final_col_name_id_map; | column_name_id_map_ = final_col_name_id_map; | ||||
| } | } | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status MapOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -171,6 +171,12 @@ class MapOp : public ParallelOp { | |||||
| // @return the number of threads consuming data from previous op's output Connector. | // @return the number of threads consuming data from previous op's output Connector. | ||||
| int32_t num_consumers() const override; | int32_t num_consumers() const override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // Local queues where worker threads can pop from. | // Local queues where worker threads can pop from. | ||||
| // Popping directly from the Connector can block if the previous designated threads haven't pop. | // Popping directly from the Connector can block if the previous designated threads haven't pop. | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -144,5 +145,11 @@ Status ProjectOp::EoeReceived(int32_t worker_id) { | |||||
| } | } | ||||
| Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } | Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } | ||||
| // Visitor accept method for NodePass | |||||
| Status ProjectOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -101,6 +101,12 @@ class ProjectOp : public PipelineOp { | |||||
| // @return Status - The error code returned. | // @return Status - The error code returned. | ||||
| Status EofReceived(int32_t worker_id) override; | Status EofReceived(int32_t worker_id) override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| std::vector<std::string> columns_to_project_; | std::vector<std::string> columns_to_project_; | ||||
| std::vector<int32_t> projected_column_indices_; | std::vector<int32_t> projected_column_indices_; | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "dataset/core/global_context.h" | #include "dataset/core/global_context.h" | ||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -170,5 +171,11 @@ Status RenameOp::EoeReceived(int32_t) { | |||||
| state_ = OpState::kDeOpIdle; | state_ = OpState::kDeOpIdle; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status RenameOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -110,6 +110,12 @@ class RenameOp : public PipelineOp { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status operator()() override; | Status operator()() override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| protected: | protected: | ||||
| // Rename core functionality | // Rename core functionality | ||||
| Status RenameColumns(); | Status RenameColumns(); | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "dataset/engine/datasetops/repeat_op.h" | #include "dataset/engine/datasetops/repeat_op.h" | ||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -187,5 +188,11 @@ int32_t RepeatOp::num_producers() const { | |||||
| return child_[0]->num_producers(); | return child_[0]->num_producers(); | ||||
| } | } | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status RepeatOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -118,6 +118,12 @@ class RepeatOp : public PipelineOp { | |||||
| // @param workerId - The worker id | // @param workerId - The worker id | ||||
| int32_t num_producers() const override; | int32_t num_producers() const override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| int32_t max_repeats_; // The number of repeats that the user requested | int32_t max_repeats_; // The number of repeats that the user requested | ||||
| int32_t repeat_count_; // A counter for the current number of executed repeats | int32_t repeat_count_; // A counter for the current number of executed repeats | ||||
| @@ -30,6 +30,7 @@ | |||||
| #include "dataset/engine/dataset_iterator.h" | #include "dataset/engine/dataset_iterator.h" | ||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "dataset/util/random.h" | #include "dataset/util/random.h" | ||||
| #include "dataset/util/status.h" | #include "dataset/util/status.h" | ||||
| @@ -296,5 +297,11 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) { | |||||
| state_ = OpState::kDeOpIdle; | state_ = OpState::kDeOpIdle; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status ShuffleOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -155,6 +155,12 @@ class ShuffleOp : public PipelineOp { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status EoeReceived(int32_t worker_id) override; | Status EoeReceived(int32_t worker_id) override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // Private function to add a new row to the shuffle buffer. | // Private function to add a new row to the shuffle buffer. | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "dataset/engine/datasetops/skip_op.h" | #include "dataset/engine/datasetops/skip_op.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -128,5 +129,11 @@ Status SkipOp::EofReceived(int32_t worker_id) { | |||||
| MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; | MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status SkipOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -74,6 +74,12 @@ class SkipOp : public PipelineOp { | |||||
| // @param worker_id - The worker id | // @param worker_id - The worker id | ||||
| Status EofReceived(int32_t worker_id) override; | Status EofReceived(int32_t worker_id) override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| int32_t max_skips_; // The number of skips that the user requested | int32_t max_skips_; // The number of skips that the user requested | ||||
| int32_t skip_count_; // A counter for the current number of executed skips | int32_t skip_count_; // A counter for the current number of executed skips | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/util/task_manager.h" | #include "dataset/util/task_manager.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -250,5 +251,11 @@ Status GeneratorOp::Reset() { | |||||
| wp_.Set(); | wp_.Set(); | ||||
| return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); | return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status GeneratorOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -121,6 +121,12 @@ class GeneratorOp : public PipelineOp { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status Reset() override; | Status Reset() override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| py::function generator_function_; | py::function generator_function_; | ||||
| std::vector<std::string> column_names_; | std::vector<std::string> column_names_; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -451,5 +452,11 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t | |||||
| (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1); | (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status ImageFolderOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -225,6 +225,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||||
| const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes, | const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes, | ||||
| int64_t dev_id = 0, int64_t num_dev = 1); | int64_t dev_id = 0, int64_t num_dev = 1); | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // Initialize Sampler, calls sampler->Init() within | // Initialize Sampler, calls sampler->Init() within | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "dataset/engine/datasetops/dataset_op.h" | #include "dataset/engine/datasetops/dataset_op.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -684,5 +685,11 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status MindRecordOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -195,6 +195,12 @@ class MindRecordOp : public ParallelOp { | |||||
| Status SetColumnsBlob(); | Status SetColumnsBlob(); | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id); | Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id); | ||||
| @@ -37,6 +37,7 @@ | |||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/engine/jagged_connector.h" | #include "dataset/engine/jagged_connector.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "dataset/util/path.h" | #include "dataset/util/path.h" | ||||
| #include "dataset/util/queue.h" | #include "dataset/util/queue.h" | ||||
| #include "dataset/util/random.h" | #include "dataset/util/random.h" | ||||
| @@ -1037,5 +1038,11 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file | |||||
| return rows_read; | return rows_read; | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status TFReaderOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -222,6 +222,12 @@ class TFReaderOp : public ParallelOp { | |||||
| static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1, | static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1, | ||||
| bool estimate = false); | bool estimate = false); | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // The entry point for when workers are launched. | // The entry point for when workers are launched. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "dataset/engine/datasetops/take_op.h" | #include "dataset/engine/datasetops/take_op.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -132,5 +133,11 @@ Status TakeOp::PrepareNodePostAction() { | |||||
| tree_->AddToRepeatStack(shared_from_this()); | tree_->AddToRepeatStack(shared_from_this()); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status TakeOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -84,6 +84,12 @@ class TakeOp : public PipelineOp { | |||||
| // before providing their own implementations. | // before providing their own implementations. | ||||
| Status PrepareNodePostAction() override; | Status PrepareNodePostAction() override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| int32_t max_takes_; // The number of takes that the user requested | int32_t max_takes_; // The number of takes that the user requested | ||||
| int32_t take_count_; // A counter for the current number of executed takes | int32_t take_count_; // A counter for the current number of executed takes | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include "dataset/core/constants.h" | #include "dataset/core/constants.h" | ||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "dataset/core/config_manager.h" | #include "dataset/core/config_manager.h" | ||||
| #include "dataset/core/global_context.h" | #include "dataset/core/global_context.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -250,5 +251,11 @@ Status ZipOp::EoeReceived(int32_t) { | |||||
| state_ = OpState::kDeOpIdle; | state_ = OpState::kDeOpIdle; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accept method for NodePass | |||||
| Status ZipOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -104,6 +104,12 @@ class ZipOp : public PipelineOp { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status operator()() override; | Status operator()() override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // Handles preprocessing of the main loop, used when starting new epoch | // Handles preprocessing of the main loop, used when starting new epoch | ||||
| Status prepare(TensorQTable *const table); | Status prepare(TensorQTable *const table); | ||||
| @@ -20,6 +20,8 @@ | |||||
| #include "dataset/engine/datasetops/shuffle_op.h" | #include "dataset/engine/datasetops/shuffle_op.h" | ||||
| #include "dataset/util/task_manager.h" | #include "dataset/util/task_manager.h" | ||||
| #include "dataset/engine/opt/util/printer_pass.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor | // Constructor | ||||
| @@ -161,10 +163,54 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // The driver of the prepare phase of the execution tree. | |||||
| // Prepare phase consists of three sub phases | |||||
| // | |||||
| // 1. PrepareTreePreAction() | |||||
| // Compulsory transformation/action pre optimization. | |||||
| // For example, CacheOp Insertion | |||||
| // | |||||
| // 2. Optimize() | |||||
| // Optimization transformation/action, optional | |||||
| // For example, MapOp Fusion | |||||
| // | |||||
| // 3. PrepareTreePostAction() | |||||
| // Compulsory transformation/action post optimization. | |||||
| // For example, repeatOp inlining | |||||
| // | |||||
| // @return Status - The error code return | |||||
| Status ExecutionTree::Prepare() { | |||||
| // Pre optimization compulsory transformation | |||||
| RETURN_IF_NOT_OK(this->PrepareTreePreAction()); | |||||
| // Optimization transformation | |||||
| RETURN_IF_NOT_OK(this->Optimize()); | |||||
| // Post optimization compulsory transformation | |||||
| RETURN_IF_NOT_OK(this->PrepareTreePostAction()); | |||||
| // Existing transformation implementation, will be removed later | |||||
| RETURN_IF_NOT_OK(this->PrepareDeprecated()); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); } | |||||
| Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); } | |||||
| Status ExecutionTree::Optimize() { | |||||
| // auto pp = new PrinterPass(); | |||||
| // bool modified = false; | |||||
| // pp->Run(this, &modified); | |||||
| return Status::OK(); | |||||
| } | |||||
| // The driver of the prepare phase of the execution tree. The prepare phase will recursively | // The driver of the prepare phase of the execution tree. The prepare phase will recursively | ||||
| // walk the tree to perform modifications to the tree or specific nodes within the tree to get | // walk the tree to perform modifications to the tree or specific nodes within the tree to get | ||||
| // it ready for execution. | // it ready for execution. | ||||
| Status ExecutionTree::Prepare() { | |||||
| // | |||||
| // This driver is deprecated. | |||||
| Status ExecutionTree::PrepareDeprecated() { | |||||
| // Tree must be in pending prepare state before we can assign root to it | // Tree must be in pending prepare state before we can assign root to it | ||||
| if (tree_state_ != kDeTStatePrepare) { | if (tree_state_ != kDeTStatePrepare) { | ||||
| std::string err_msg = | std::string err_msg = | ||||
| @@ -152,11 +152,41 @@ class ExecutionTree { | |||||
| // @return the prepare flags | // @return the prepare flags | ||||
| uint32_t PrepareFlags() const { return prepare_flags_; } | uint32_t PrepareFlags() const { return prepare_flags_; } | ||||
| // The driver of the prepare phase of the execution tree. The prepare phase will recursively | |||||
| // The driver of the prepare phase of the execution tree. | |||||
| // Prepare phase consists of three sub phases | |||||
| // | |||||
| // 1. PrepareTreePreAction() | |||||
| // Compulsory transformation/action pre optimization. | |||||
| // For example, CacheOp Insertion | |||||
| // | |||||
| // 2. Optimize() | |||||
| // Optimization transformation/action, optional | |||||
| // For example, MapOp Fusion | |||||
| // | |||||
| // 3. PrepareTreePostAction() | |||||
| // Compulsory transformation/action post optimization. | |||||
| // For example, repeatOp inlining | |||||
| // | |||||
| // @return Status - The error code return | |||||
| Status Prepare(); | |||||
| // Compulsory transformation/action pre optimization. | |||||
| // @return Status - The error code return | |||||
| Status PrepareTreePreAction(); | |||||
| // Compulsory transformation/action post optimization. | |||||
| // @return Status - The error code return | |||||
| Status PrepareTreePostAction(); | |||||
| // Optimization transformation/action, optional. | |||||
| // @return Status - The error code return | |||||
| Status Optimize(); | |||||
| // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively | |||||
| // walk the tree to perform modifications to the tree or specific nodes within the tree to get | // walk the tree to perform modifications to the tree or specific nodes within the tree to get | ||||
| // it ready for execution. | // it ready for execution. | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status Prepare(); | |||||
| Status PrepareDeprecated(); | |||||
| // Recursive function used during prepare phase to visit a node and drive any pre- and post- | // Recursive function used during prepare phase to visit a node and drive any pre- and post- | ||||
| // node actions during a tree walk. | // node actions during a tree walk. | ||||
| @@ -0,0 +1,6 @@ | |||||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||||
| add_library(engine-opt OBJECT | |||||
| pass.cc | |||||
| util/printer_pass.cc | |||||
| ) | |||||
| @@ -0,0 +1,157 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "dataset/engine/datasetops/batch_op.h" | |||||
| #include "dataset/engine/datasetops/dataset_op.h" | |||||
| #include "dataset/engine/datasetops/device_queue_op.h" | |||||
| #include "dataset/engine/datasetops/map_op.h" | |||||
| #include "dataset/engine/datasetops/project_op.h" | |||||
| #include "dataset/engine/datasetops/rename_op.h" | |||||
| #include "dataset/engine/datasetops/filter_op.h" | |||||
| #include "dataset/engine/datasetops/repeat_op.h" | |||||
| #include "dataset/engine/datasetops/skip_op.h" | |||||
| #include "dataset/engine/datasetops/shuffle_op.h" | |||||
| #include "dataset/engine/datasetops/source/generator_op.h" | |||||
| #include "dataset/engine/datasetops/source/mindrecord_op.h" | |||||
| #include "dataset/engine/datasetops/source/storage_op.h" | |||||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | |||||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||||
| #include "dataset/engine/datasetops/take_op.h" | |||||
| #include "dataset/engine/datasetops/zip_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // Driver method for TreePass | |||||
| Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); } | |||||
| // Driver method for NodePass | |||||
| Status NodePass::Run(ExecutionTree *tree, bool *modified) { | |||||
| std::shared_ptr<DatasetOp> root = tree->root(); | |||||
| if (traversalOrder_ == Order::DFS) { | |||||
| // DFS | |||||
| return DFSNodeVisit(root, modified); | |||||
| } else if (traversalOrder_ == Order::BFS) { | |||||
| // BFS | |||||
| return BFSNodeVisit(root, modified); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to perform DFS visit | |||||
| Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) { | |||||
| for (const auto &c : node->Children()) { | |||||
| RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); | |||||
| } | |||||
| return node->Accept(this, modified); | |||||
| } | |||||
| // Helper function to perform BFS visit | |||||
| Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified) { | |||||
| // Initialize bfs queue with root | |||||
| std::queue<std::shared_ptr<DatasetOp>> bfsQueue; | |||||
| bfsQueue.push(root); | |||||
| // BFS loop | |||||
| while (!bfsQueue.empty()) { | |||||
| // Pop the front of the bfs queue | |||||
| auto curNode = bfsQueue.front(); | |||||
| bfsQueue.pop(); | |||||
| // Run node pass | |||||
| RETURN_IF_NOT_OK(curNode->Accept(this, modified)); | |||||
| // Push children into bfs queue | |||||
| for (const auto &c : curNode->Children()) { | |||||
| bfsQueue.push(c); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,146 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_ENGINE_OPT_PASS_H_ | |||||
| #define DATASET_ENGINE_OPT_PASS_H_ | |||||
| #include <memory> | |||||
| #include <queue> | |||||
| #include "dataset/engine/execution_tree.h" | |||||
| #include "dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class BatchOp; | |||||
| class MapOp; | |||||
| class ProjectOp; | |||||
| class RenameOp; | |||||
| class FilterOp; | |||||
| class SkipOp; | |||||
| class ShuffleOp; | |||||
| class GeneratorOp; | |||||
| class MindRecordOp; | |||||
| class TFReaderOp; | |||||
| class TakeOp; | |||||
| class ZipOp; | |||||
| class DeviceQueueOp; | |||||
| class ImageFolderOp; | |||||
| // The base class Pass is the basic unit of tree transformation. | |||||
| // The actual implementation of the passes will be derived from here. | |||||
| class Pass : public std::enable_shared_from_this<Pass> { | |||||
| public: | |||||
| // Run the transformation pass again the execution tree. | |||||
| // @param tree - Pointer to the execution tree to be transformed. | |||||
| // @param modified - Pointer to the modified flag, | |||||
| virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); } | |||||
| }; | |||||
| // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. | |||||
| class TreePass : public Pass { | |||||
| public: | |||||
| // Run the transformation pass against the execution tree. | |||||
| // @param tree - Pointer to the execution tree to be transformed. | |||||
| // @param modified - Pointer to the modified flag, | |||||
| Status Run(ExecutionTree *tree, bool *modified) final; | |||||
| // Derived classes may implement the runOnTree function to implement tree transformation. | |||||
| // "modified" flag needs to be set to true if tree is modified during the pass execution. | |||||
| // @return Status - The error code return | |||||
| virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | |||||
| }; | |||||
| // NodePass is a basic Pass class which performs transformation on Node visiting. | |||||
| // NodePass implements Visitor design pattern. | |||||
| class NodePass : public Pass { | |||||
| public: | |||||
| // Tree traversal order | |||||
| enum Order { DFS, BFS }; | |||||
| // Constructor | |||||
| // Default DFS traversal | |||||
| explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } | |||||
| // Run the transformation pass against the execution tree. | |||||
| // @param tree - Pointer to the execution tree to be transformed. | |||||
| // @param modified - Pointer to the modified flag, | |||||
| Status Run(ExecutionTree *tree, bool *modified) final; | |||||
| // Derived classes may implement the runOnNode function to implement node level tree transformation. | |||||
| // "modified" flag needs to be set to true if tree is modified during the pass execution. | |||||
| // @return Status - The error code return | |||||
| virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | |||||
| // Visit methods to be overridden. | |||||
| // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode | |||||
| // of its own type and override "Accept" from DatasetOp. | |||||
| virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); | |||||
| private: | |||||
| // Helper function to perform DFS visit | |||||
| Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified); | |||||
| // Helper function to perform BFS visit | |||||
| Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified); | |||||
| // Tree traversal order of the NodePass | |||||
| Order traversalOrder_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_ENGINE_OPT_PASS_H_ | |||||
| @@ -0,0 +1,111 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "dataset/engine/opt/util/printer_pass.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting DatasetOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting BatchOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting MapOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting ProjectOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting RenameOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting FilterOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting SkipOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting ShuffleOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting GeneratorOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting MindRecordOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting TFReaderOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting TakeOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting ZipOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting DeviceQueueOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| std::cout << "Visiting ImageFolderOp" << '\n'; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H | |||||
| #define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H | |||||
| #include <memory> | |||||
| #include "dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class PrinterPass : public NodePass { | |||||
| public: | |||||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.dataset as ds | |||||
| # Generate 1d int numpy array from 0 - 63 | |||||
| def generator_1d(): | |||||
| for i in range(64): | |||||
| yield (np.array([i]),) | |||||
| def test_case_0(): | |||||
| """ | |||||
| Test 1D Generator | |||||
| """ | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||||
| data1 = data1.shuffle(2) | |||||
| data1 = data1.map(["data"], operations=(lambda x : x)) | |||||
| data1 = data1.batch(2) | |||||
| i = 0 | |||||
| for item in data1.create_dict_iterator(): # each data is a dictionary | |||||
| pass | |||||
| if __name__ == "__main__": | |||||
| test_case_0() | |||||