Browse Source

!1272 [Dataset] MindData Tree Optimizer Infrastructure

Merge pull request !1272 from JunhanHu/minddata_opt
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
93e7c97a96
42 changed files with 826 additions and 5 deletions
  1. +1
    -0
      mindspore/ccsrc/dataset/CMakeLists.txt
  2. +3
    -2
      mindspore/ccsrc/dataset/engine/CMakeLists.txt
  3. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc
  4. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/batch_op.h
  5. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
  6. +12
    -0
      mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
  7. +8
    -0
      mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc
  8. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h
  9. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc
  10. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/filter_op.h
  11. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/map_op.cc
  12. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/map_op.h
  13. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/project_op.cc
  14. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/project_op.h
  15. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc
  16. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/rename_op.h
  17. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
  18. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
  19. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc
  20. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h
  21. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
  22. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
  23. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc
  24. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h
  25. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
  26. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
  27. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
  28. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
  29. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
  30. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
  31. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
  32. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/take_op.h
  33. +7
    -0
      mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc
  34. +6
    -0
      mindspore/ccsrc/dataset/engine/datasetops/zip_op.h
  35. +47
    -1
      mindspore/ccsrc/dataset/engine/execution_tree.cc
  36. +32
    -2
      mindspore/ccsrc/dataset/engine/execution_tree.h
  37. +6
    -0
      mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt
  38. +157
    -0
      mindspore/ccsrc/dataset/engine/opt/pass.cc
  39. +146
    -0
      mindspore/ccsrc/dataset/engine/opt/pass.h
  40. +111
    -0
      mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc
  41. +62
    -0
      mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h
  42. +46
    -0
      tests/ut/python/dataset/test_opt.py

+ 1
- 0
mindspore/ccsrc/dataset/CMakeLists.txt View File

@@ -66,6 +66,7 @@ set(submodules
$<TARGET_OBJECTS:engine-datasetops-source> $<TARGET_OBJECTS:engine-datasetops-source>
$<TARGET_OBJECTS:engine-datasetops-source-sampler> $<TARGET_OBJECTS:engine-datasetops-source-sampler>
$<TARGET_OBJECTS:engine-datasetops> $<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine> $<TARGET_OBJECTS:engine>
) )




+ 3
- 2
mindspore/ccsrc/dataset/engine/CMakeLists.txt View File

@@ -1,4 +1,5 @@
add_subdirectory(datasetops) add_subdirectory(datasetops)
add_subdirectory(opt)
if (ENABLE_TDTQUE) if (ENABLE_TDTQUE)
add_subdirectory(tdt) add_subdirectory(tdt)
endif () endif ()
@@ -14,7 +15,7 @@ add_library(engine OBJECT
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})


if (ENABLE_TDTQUE) if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt)
else() else()
add_dependencies(engine engine-datasetops engine-datasetops-source)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt)
endif () endif ()

+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc View File

@@ -22,6 +22,7 @@
#include "dataset/core/pybind_support.h" #include "dataset/core/pybind_support.h"
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"


using float16 = Eigen::half; using float16 = Eigen::half;


@@ -462,5 +463,11 @@ Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> d
return Status::OK(); return Status::OK();
} }


// Visitor accept method for NodePass
Status BatchOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified);
}

} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/batch_op.h View File

@@ -192,6 +192,12 @@ class BatchOp : public ParallelOp {
Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape, Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
float pad_val); float pad_val);


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
// recursive helper function. This function could be very expensive if called on a multi-dimensional tensor // recursive helper function. This function could be very expensive if called on a multi-dimensional tensor
// it is only meant to be called by PadTensor. // it is only meant to be called by PadTensor.


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc View File

@@ -25,6 +25,7 @@
#include "dataset/engine/datasetops/device_queue_op.h" #include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"


#include "utils/log_adapter.h" #include "utils/log_adapter.h"


@@ -249,5 +250,11 @@ Status DatasetOp::AssignColMapFromChild() {
} }
return Status::OK(); return Status::OK();
} }

Status DatasetOp::Accept(NodePass *p, bool *modified) {
// DatasetOp is the base class of visitor target.
// This method will only be called if its derived class does not implement one.
return p->RunOnNode(shared_from_this(), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 12
- 0
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h View File

@@ -32,6 +32,8 @@ class ExecutionTree;


class DataBuffer; class DataBuffer;


class NodePass;

// The base class DatasetOp is the main tree node. It is an abstract class, so // The base class DatasetOp is the main tree node. It is an abstract class, so
// the actual implementation of the operators will be derived from here. // the actual implementation of the operators will be derived from here.
class DatasetOp : public std::enable_shared_from_this<DatasetOp> { class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
@@ -209,6 +211,16 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// @return - the column name map as a string // @return - the column name map as a string
std::string ColumnNameMapAsString() const; std::string ColumnNameMapAsString() const;


// Children Getter
// @return Vector or Children
std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }

// Base method for NodePass visit.
// Subclass needs to override this if it requires special node visit access.
// Check "dataset/engine/opt/pass.h" for more details.
// @return Statue of the node visit
virtual Status Accept(NodePass *p, bool *modified);

protected: protected:
// Adds a parent operator to this operator // Adds a parent operator to this operator
// @notes External callers do not have access to this function. // @notes External callers do not have access to this function.


+ 8
- 0
mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc View File

@@ -24,6 +24,7 @@
#include "dataset/engine/dataset_iterator.h" #include "dataset/engine/dataset_iterator.h"
#include "dataset/util/status.h" #include "dataset/util/status.h"
#include "dataset/util/task_manager.h" #include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"


#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
#include "tdt/tsd_client.h" #include "tdt/tsd_client.h"
@@ -265,5 +266,12 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n"; out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n";
} }
} }

// Visitor accept method for NodePass
Status DeviceQueueOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified);
}

} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h View File

@@ -134,6 +134,12 @@ class DeviceQueueOp : public PipelineOp {


Status operator()() override; Status operator()() override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
// Name: checkExceptions(DataBuffer); // Name: checkExceptions(DataBuffer);
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc View File

@@ -27,6 +27,7 @@
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h" #include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/kernels/tensor_op.h" #include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "dataset/util/task_manager.h" #include "dataset/util/task_manager.h"
@@ -259,5 +260,11 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
} }
return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
} }

// Visitor accept method for NodePass
Status FilterOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/filter_op.h View File

@@ -121,6 +121,12 @@ class FilterOp : public ParallelOp {
// @param show_all A bool to control if you want to show all info or just a summary. // @param show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override; void Print(std::ostream &out, bool show_all) const override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
// predicate_func python callable which returns a boolean value. // predicate_func python callable which returns a boolean value.
py::function predicate_func_; py::function predicate_func_;


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/map_op.cc View File

@@ -27,6 +27,7 @@
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h" #include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/kernels/tensor_op.h" #include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "dataset/util/task_manager.h" #include "dataset/util/task_manager.h"
@@ -370,5 +371,11 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name
column_name_id_map_ = final_col_name_id_map; column_name_id_map_ = final_col_name_id_map;
} }
} }

// Visitor accept method for NodePass
Status MapOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/map_op.h View File

@@ -171,6 +171,12 @@ class MapOp : public ParallelOp {
// @return the number of threads consuming data from previous op's output Connector. // @return the number of threads consuming data from previous op's output Connector.
int32_t num_consumers() const override; int32_t num_consumers() const override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
// Local queues where worker threads can pop from. // Local queues where worker threads can pop from.
// Popping directly from the Connector can block if the previous designated threads haven't pop. // Popping directly from the Connector can block if the previous designated threads haven't pop.


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/project_op.cc View File

@@ -25,6 +25,7 @@
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h" #include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"


namespace mindspore { namespace mindspore {
@@ -144,5 +145,11 @@ Status ProjectOp::EoeReceived(int32_t worker_id) {
} }


Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }

// Visitor accept method for NodePass
Status ProjectOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/project_op.h View File

@@ -101,6 +101,12 @@ class ProjectOp : public PipelineOp {
// @return Status - The error code returned. // @return Status - The error code returned.
Status EofReceived(int32_t worker_id) override; Status EofReceived(int32_t worker_id) override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
std::vector<std::string> columns_to_project_; std::vector<std::string> columns_to_project_;
std::vector<int32_t> projected_column_indices_; std::vector<int32_t> projected_column_indices_;


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc View File

@@ -24,6 +24,7 @@
#include "dataset/core/global_context.h" #include "dataset/core/global_context.h"
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"


namespace mindspore { namespace mindspore {
@@ -170,5 +171,11 @@ Status RenameOp::EoeReceived(int32_t) {
state_ = OpState::kDeOpIdle; state_ = OpState::kDeOpIdle;
return Status::OK(); return Status::OK();
} }

// Visitor accept method for NodePass
Status RenameOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/rename_op.h View File

@@ -110,6 +110,12 @@ class RenameOp : public PipelineOp {
// @return Status - The error code return // @return Status - The error code return
Status operator()() override; Status operator()() override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

protected: protected:
// Rename core functionality // Rename core functionality
Status RenameColumns(); Status RenameColumns();


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc View File

@@ -21,6 +21,7 @@
#include "dataset/engine/datasetops/repeat_op.h" #include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"


#include "utils/log_adapter.h" #include "utils/log_adapter.h"


@@ -187,5 +188,11 @@ int32_t RepeatOp::num_producers() const {
return child_[0]->num_producers(); return child_[0]->num_producers();
} }
} }

// Visitor accept method for NodePass
Status RepeatOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h View File

@@ -118,6 +118,12 @@ class RepeatOp : public PipelineOp {
// @param workerId - The worker id // @param workerId - The worker id
int32_t num_producers() const override; int32_t num_producers() const override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
int32_t max_repeats_; // The number of repeats that the user requested int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats int32_t repeat_count_; // A counter for the current number of executed repeats


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc View File

@@ -30,6 +30,7 @@
#include "dataset/engine/dataset_iterator.h" #include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/random.h" #include "dataset/util/random.h"
#include "dataset/util/status.h" #include "dataset/util/status.h"


@@ -296,5 +297,11 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) {
state_ = OpState::kDeOpIdle; state_ = OpState::kDeOpIdle;
return Status::OK(); return Status::OK();
} }

// Visitor accept method for NodePass
Status ShuffleOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h View File

@@ -155,6 +155,12 @@ class ShuffleOp : public PipelineOp {
// @return Status - The error code return // @return Status - The error code return
Status EoeReceived(int32_t worker_id) override; Status EoeReceived(int32_t worker_id) override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
// Private function to add a new row to the shuffle buffer. // Private function to add a new row to the shuffle buffer.
// @return Status - The error code return // @return Status - The error code return


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc View File

@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h" #include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"


#include "utils/log_adapter.h" #include "utils/log_adapter.h"


@@ -128,5 +129,11 @@ Status SkipOp::EofReceived(int32_t worker_id) {
MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
return Status::OK(); return Status::OK();
} }

// Visitor accept method for NodePass
Status SkipOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h View File

@@ -74,6 +74,12 @@ class SkipOp : public PipelineOp {
// @param worker_id - The worker id // @param worker_id - The worker id
Status EofReceived(int32_t worker_id) override; Status EofReceived(int32_t worker_id) override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
int32_t max_skips_; // The number of skips that the user requested int32_t max_skips_; // The number of skips that the user requested
int32_t skip_count_; // A counter for the current number of executed skips int32_t skip_count_; // A counter for the current number of executed skips


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc View File

@@ -20,6 +20,7 @@
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/execution_tree.h" #include "dataset/engine/execution_tree.h"
#include "dataset/util/task_manager.h" #include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@@ -250,5 +251,11 @@ Status GeneratorOp::Reset() {
wp_.Set(); wp_.Set();
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
} }

// Visitor accept method for NodePass
Status GeneratorOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h View File

@@ -121,6 +121,12 @@ class GeneratorOp : public PipelineOp {
// @return Status - The error code return // @return Status - The error code return
Status Reset() override; Status Reset() override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
py::function generator_function_; py::function generator_function_;
std::vector<std::string> column_names_; std::vector<std::string> column_names_;


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc View File

@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h" #include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@@ -451,5 +452,11 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
(*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1); (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
return Status::OK(); return Status::OK();
} }

// Visitor accept method for NodePass
Status ImageFolderOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h View File

@@ -225,6 +225,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes, const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
int64_t dev_id = 0, int64_t num_dev = 1); int64_t dev_id = 0, int64_t num_dev = 1);


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
// Initialize Sampler, calls sampler->Init() within // Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return // @return Status - The error code return


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc View File

@@ -29,6 +29,7 @@
#include "dataset/engine/datasetops/dataset_op.h" #include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h" #include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"


namespace mindspore { namespace mindspore {
@@ -684,5 +685,11 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path,
} }
return Status::OK(); return Status::OK();
} }

// Visitor accept method for NodePass
Status MindRecordOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h View File

@@ -195,6 +195,12 @@ class MindRecordOp : public ParallelOp {


Status SetColumnsBlob(); Status SetColumnsBlob();


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id); Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);




+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc View File

@@ -37,6 +37,7 @@
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h" #include "dataset/engine/execution_tree.h"
#include "dataset/engine/jagged_connector.h" #include "dataset/engine/jagged_connector.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/path.h" #include "dataset/util/path.h"
#include "dataset/util/queue.h" #include "dataset/util/queue.h"
#include "dataset/util/random.h" #include "dataset/util/random.h"
@@ -1037,5 +1038,11 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file


return rows_read; return rows_read;
} }

// Visitor accept method for NodePass
Status TFReaderOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h View File

@@ -222,6 +222,12 @@ class TFReaderOp : public ParallelOp {
static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1, static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1,
bool estimate = false); bool estimate = false);


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
// The entry point for when workers are launched. // The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc View File

@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/take_op.h" #include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h" #include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@@ -132,5 +133,11 @@ Status TakeOp::PrepareNodePostAction() {
tree_->AddToRepeatStack(shared_from_this()); tree_->AddToRepeatStack(shared_from_this());
return Status::OK(); return Status::OK();
} }

// Visitor accept method for NodePass
Status TakeOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/take_op.h View File

@@ -84,6 +84,12 @@ class TakeOp : public PipelineOp {
// before providing their own implementations. // before providing their own implementations.
Status PrepareNodePostAction() override; Status PrepareNodePostAction() override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
int32_t max_takes_; // The number of takes that the user requested int32_t max_takes_; // The number of takes that the user requested
int32_t take_count_; // A counter for the current number of executed takes int32_t take_count_; // A counter for the current number of executed takes


+ 7
- 0
mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc View File

@@ -19,6 +19,7 @@
#include "dataset/core/constants.h" #include "dataset/core/constants.h"
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/core/config_manager.h" #include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h" #include "dataset/core/global_context.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@@ -250,5 +251,11 @@ Status ZipOp::EoeReceived(int32_t) {
state_ = OpState::kDeOpIdle; state_ = OpState::kDeOpIdle;
return Status::OK(); return Status::OK();
} }

// Visitor accept method for NodePass
Status ZipOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 0
mindspore/ccsrc/dataset/engine/datasetops/zip_op.h View File

@@ -104,6 +104,12 @@ class ZipOp : public PipelineOp {
// @return Status - The error code return // @return Status - The error code return
Status operator()() override; Status operator()() override;


// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private: private:
// Handles preprocessing of the main loop, used when starting new epoch // Handles preprocessing of the main loop, used when starting new epoch
Status prepare(TensorQTable *const table); Status prepare(TensorQTable *const table);


+ 47
- 1
mindspore/ccsrc/dataset/engine/execution_tree.cc View File

@@ -20,6 +20,8 @@
#include "dataset/engine/datasetops/shuffle_op.h" #include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/util/task_manager.h" #include "dataset/util/task_manager.h"


#include "dataset/engine/opt/util/printer_pass.h"

namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Constructor // Constructor
@@ -161,10 +163,54 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
return Status::OK(); return Status::OK();
} }


// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PrepareTreePreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
// 2. Optimize()
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PrepareTreePostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status - The error code return
Status ExecutionTree::Prepare() {
// Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePreAction());

// Optimization transformation
RETURN_IF_NOT_OK(this->Optimize());

// Post optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePostAction());

// Existing transformation implementation, will be removed later
RETURN_IF_NOT_OK(this->PrepareDeprecated());
return Status::OK();
}

Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); }

Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); }

Status ExecutionTree::Optimize() {
// auto pp = new PrinterPass();
// bool modified = false;
// pp->Run(this, &modified);
return Status::OK();
}

// The driver of the prepare phase of the execution tree. The prepare phase will recursively // The driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get // walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution. // it ready for execution.
Status ExecutionTree::Prepare() {
//
// This driver is deprecated.
Status ExecutionTree::PrepareDeprecated() {
// Tree must be in pending prepare state before we can assign root to it // Tree must be in pending prepare state before we can assign root to it
if (tree_state_ != kDeTStatePrepare) { if (tree_state_ != kDeTStatePrepare) {
std::string err_msg = std::string err_msg =


+ 32
- 2
mindspore/ccsrc/dataset/engine/execution_tree.h View File

@@ -152,11 +152,41 @@ class ExecutionTree {
// @return the prepare flags // @return the prepare flags
uint32_t PrepareFlags() const { return prepare_flags_; } uint32_t PrepareFlags() const { return prepare_flags_; }


// The driver of the prepare phase of the execution tree. The prepare phase will recursively
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PrepareTreePreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
// 2. Optimize()
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PrepareTreePostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status - The error code return
Status Prepare();

// Compulsory transformation/action pre optimization.
// @return Status - The error code return
Status PrepareTreePreAction();

// Compulsory transformation/action post optimization.
// @return Status - The error code return
Status PrepareTreePostAction();

// Optimization transformation/action, optional.
// @return Status - The error code return
Status Optimize();

// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get // walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution. // it ready for execution.
// @return Status - The error code return // @return Status - The error code return
Status Prepare();
Status PrepareDeprecated();


// Recursive function used during prepare phase to visit a node and drive any pre- and post- // Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk. // node actions during a tree walk.


+ 6
- 0
mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt View File

@@ -0,0 +1,6 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-opt OBJECT
pass.cc
util/printer_pass.cc
)

+ 157
- 0
mindspore/ccsrc/dataset/engine/opt/pass.cc View File

@@ -0,0 +1,157 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "dataset/engine/opt/pass.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/datasetops/map_op.h"
#include "dataset/engine/datasetops/project_op.h"
#include "dataset/engine/datasetops/rename_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/zip_op.h"

namespace mindspore {
namespace dataset {

// Driver method for TreePass
Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); }

// Driver method for NodePass
Status NodePass::Run(ExecutionTree *tree, bool *modified) {
std::shared_ptr<DatasetOp> root = tree->root();
if (traversalOrder_ == Order::DFS) {
// DFS
return DFSNodeVisit(root, modified);
} else if (traversalOrder_ == Order::BFS) {
// BFS
return BFSNodeVisit(root, modified);
}
return Status::OK();
}

// Helper function to perform DFS visit
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) {
for (const auto &c : node->Children()) {
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
}
return node->Accept(this, modified);
}

// Helper function to perform BFS visit
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified) {
// Initialize bfs queue with root
std::queue<std::shared_ptr<DatasetOp>> bfsQueue;
bfsQueue.push(root);

// BFS loop
while (!bfsQueue.empty()) {
// Pop the front of the bfs queue
auto curNode = bfsQueue.front();
bfsQueue.pop();

// Run node pass
RETURN_IF_NOT_OK(curNode->Accept(this, modified));

// Push children into bfs queue
for (const auto &c : curNode->Children()) {
bfsQueue.push(c);
}
}
return Status::OK();
}

Status NodePass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

} // namespace dataset
} // namespace mindspore

+ 146
- 0
mindspore/ccsrc/dataset/engine/opt/pass.h View File

@@ -0,0 +1,146 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef DATASET_ENGINE_OPT_PASS_H_
#define DATASET_ENGINE_OPT_PASS_H_

#include <memory>
#include <queue>

#include "dataset/engine/execution_tree.h"
#include "dataset/util/status.h"

namespace mindspore {
namespace dataset {
class BatchOp;

class MapOp;

class ProjectOp;

class RenameOp;

class FilterOp;

class SkipOp;

class ShuffleOp;

class GeneratorOp;

class MindRecordOp;

class TFReaderOp;

class TakeOp;

class ZipOp;

class DeviceQueueOp;

class ImageFolderOp;

// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
public:
// Run the transformation pass again the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); }
};

// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
class TreePass : public Pass {
public:
// Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
Status Run(ExecutionTree *tree, bool *modified) final;

// Derived classes may implement the runOnTree function to implement tree transformation.
// "modified" flag needs to be set to true if tree is modified during the pass execution.
// @return Status - The error code return
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
};

// NodePass is a basic Pass class which performs transformation on Node visiting.
// NodePass implements Visitor design pattern.
class NodePass : public Pass {
public:
// Tree traversal order
enum Order { DFS, BFS };

// Constructor
// Default DFS traversal
explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; }

// Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
Status Run(ExecutionTree *tree, bool *modified) final;

// Derived classes may implement the runOnNode function to implement node level tree transformation.
// "modified" flag needs to be set to true if tree is modified during the pass execution.
// @return Status - The error code return
virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }

// Visit methods to be overridden.
// Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode
// of its own type and override "Accept" from DatasetOp.
virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);

private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);

// Helper function to perform BFS visit
Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);

// Tree traversal order of the NodePass
Order traversalOrder_;
};

} // namespace dataset
} // namespace mindspore

#endif // DATASET_ENGINE_OPT_PASS_H_

+ 111
- 0
mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc View File

@@ -0,0 +1,111 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include "dataset/engine/opt/util/printer_pass.h"

namespace mindspore {
namespace dataset {

Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting DatasetOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting BatchOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting MapOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting ProjectOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting RenameOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting FilterOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting SkipOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting ShuffleOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting GeneratorOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting MindRecordOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting TFReaderOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting TakeOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting ZipOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting DeviceQueueOp" << '\n';
return Status::OK();
}

Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting ImageFolderOp" << '\n';
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 62
- 0
mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h View File

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H

#include <memory>
#include "dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {

class PrinterPass : public NodePass {
public:
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;

Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
};

} // namespace dataset
} // namespace mindspore

#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H

+ 46
- 0
tests/ut/python/dataset/test_opt.py View File

@@ -0,0 +1,46 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest

import mindspore.dataset as ds

# Generate 1d int numpy array from 0 - 63
def generator_1d():
for i in range(64):
yield (np.array([i]),)


def test_case_0():
"""
Test 1D Generator
"""

# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])

data1 = data1.shuffle(2)

data1 = data1.map(["data"], operations=(lambda x : x))

data1 = data1.batch(2)

i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
pass


if __name__ == "__main__":
test_case_0()

Loading…
Cancel
Save