Browse Source

Add optimizer to IR tree #1

tags/v1.1.0
Nat Sutyanyong 5 years ago
parent
commit
80d02d6dcd
4 changed files with 712 additions and 0 deletions
  1. +14
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  2. +19
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  3. +432
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
  4. +247
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.h

+ 14
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -20,6 +20,7 @@
#include <memory> #include <memory>
#include <set> #include <set>


#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/random.h" #include "minddata/dataset/util/random.h"


namespace mindspore { namespace mindspore {
@@ -227,6 +228,7 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
num_workers_ = num_workers; num_workers_ = num_workers;
return shared_from_this(); return shared_from_this();
} }

DatasetNode::DatasetNode() { DatasetNode::DatasetNode() {
// Fetch some default value from config manager // Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
@@ -236,5 +238,17 @@ DatasetNode::DatasetNode() {
worker_connector_size_ = cfg->worker_connector_size(); worker_connector_size_ = cfg->worker_connector_size();
} }


// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
Status DatasetNode::Accept(NodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one.
return p->Visit(shared_from_this(), modified);
}

// In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
// after all child nodes are visited.
Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one.
return p->VisitAfter(shared_from_this(), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 19
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -31,6 +31,7 @@ namespace dataset {


class Dataset; class Dataset;
class SamplerObj; class SamplerObj;
class NodePass;


#define RETURN_EMPTY_IF_ERROR(_s) \ #define RETURN_EMPTY_IF_ERROR(_s) \
do { \ do { \
@@ -107,6 +108,24 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Shared pointer to the original object /// \return Shared pointer to the original object
std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers); std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);


/// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
/// visit on the way back up the tree after its descendants are visited.
/// \notes Subclass needs to override this if it requires special node visit access.
/// Check "dataset/engine/opt/pass.h" for more details.
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
virtual Status Accept(NodePass *p, bool *modified);

/// \brief Base method for NodePass visit on the way back up the tree after its descendants are visited.
/// \notes Subclass needs to override this if it requires special node visit access.
/// Check "dataset/engine/opt/pass.h" for more details.
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
virtual Status AcceptAfter(NodePass *p, bool *modified);

protected: protected:
std::vector<std::shared_ptr<DatasetNode>> children; std::vector<std::shared_ptr<DatasetNode>> children;
std::shared_ptr<DatasetNode> parent; std::shared_ptr<DatasetNode> parent;


+ 432
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc View File

@@ -15,6 +15,56 @@
*/ */


#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#endif
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#endif
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
#include "minddata/dataset/engine/datasetops/batch_op.h" #include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/build_vocab_op.h" #include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
@@ -57,10 +107,391 @@
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/take_op.h" #include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/engine/datasetops/zip_op.h"
//////////////////////////////////


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


// Driver method for TreePass
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }

// Driver method for NodePass
Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
if (root_ir == nullptr || modified == nullptr) {
return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
}
if (traversalOrder_ == Order::DFS) {
// DFS
return DFSNodeVisit(root_ir, modified);
} else if (traversalOrder_ == Order::BFS) {
// BFS
return BFSNodeVisit(root_ir, modified);
}
return Status::OK();
}

// Helper function to perform DFS visit
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
RETURN_IF_NOT_OK(node_ir->Accept(this, modified));
for (const auto &c : node_ir->Children()) {
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
}
return node_ir->AcceptAfter(this, modified);
}

// Helper function to perform BFS visit
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
// Initialize bfs queue with root
std::queue<std::shared_ptr<DatasetNode>> bfsQueue;
bfsQueue.push(node_ir);

// BFS loop
while (!bfsQueue.empty()) {
// Pop the front of the bfs queue
auto curNode = bfsQueue.front();
bfsQueue.pop();

// Run node pass
RETURN_IF_NOT_OK(curNode->Accept(this, modified));

// Push children into bfs queue
for (const auto &c : curNode->Children()) {
bfsQueue.push(c);
}
}
return Status::OK();
}

// For datasetops IR
Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifdef ENABLE_PYTHON
Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

// For datasetops/source IR
Status NodePass::Visit(std::shared_ptr<AlbumNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<CelebANode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CelebANode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<Cifar100Node> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<Cifar10Node> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<CLUENode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CLUENode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<CocoNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CocoNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<CSVNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CSVNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

#ifdef ENABLE_PYTHON
Status NodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<ImageFolderNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ManifestNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<MindDataNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<MnistNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<MnistNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<RandomNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<TextFileNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<VOCNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<VOCNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Driver method for TreePass // Driver method for TreePass
Status TreePass::Run(ExecutionTree *tree, bool *modified) { Status TreePass::Run(ExecutionTree *tree, bool *modified) {
if (tree == nullptr || modified == nullptr) { if (tree == nullptr || modified == nullptr) {
@@ -320,5 +751,6 @@ Status NodePass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
} }
#endif #endif
//////////////////////////////////
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 247
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.h View File

@@ -21,10 +21,61 @@
#include <queue> #include <queue>


#include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class BatchNode;
class BucketBatchByLengthNode;
#ifndef ENABLE_ANDROID
class BuildSentenceVocabNode;
#endif
class BuildVocabNode;
class ConcatNode;
class MapNode;
class ProjectNode;
class RenameNode;
class RepeatNode;
class ShuffleNode;
class SkipNode;
#ifdef ENABLE_PYTHON
class SyncWaitNode;
#endif
class TakeNode;
class TransferNode;
class ZipNode;
class AlbumNode;
class CelebANode;
class Cifar100Node;
class Cifar10Node;
#ifndef ENABLE_ANDROID
class CLUENode;
#endif
class CocoNode;
#ifndef ENABLE_ANDROID
class CSVNode;
#endif
#ifdef ENABLE_PYTHON
class GeneratorNode;
#endif
class ImageFolderNode;
class ManifestNode;
#ifndef ENABLE_ANDROID
class MindDataNode;
#endif
class MnistNode;
class RandomNode;
#ifndef ENABLE_ANDROID
class TextFileNode;
#endif
#ifndef ENABLE_ANDROID
class TFRecordNode;
#endif
class VOCNode;

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
class BatchOp; class BatchOp;


class MapOp; class MapOp;
@@ -94,15 +145,24 @@ class FilterOp;


class GeneratorOp; class GeneratorOp;
#endif #endif
//////////////////////////////////


// The base class Pass is the basic unit of tree transformation. // The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here. // The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> { class Pass : public std::enable_shared_from_this<Pass> {
public: public:
// Run the transformation pass against the IR tree.
// @param root_ir - Pointer to the IR tree to be transformed.
// @param modified - Pointer to the modified flag,
virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) = 0;

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Run the transformation pass against the execution tree. // Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed. // @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag, // @param modified - Pointer to the modified flag,
virtual Status Run(ExecutionTree *tree, bool *modified) = 0; virtual Status Run(ExecutionTree *tree, bool *modified) = 0;
//////////////////////////////////


virtual ~Pass() = default; virtual ~Pass() = default;
}; };
@@ -110,6 +170,13 @@ class Pass : public std::enable_shared_from_this<Pass> {
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly. // TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
class TreePass : public Pass { class TreePass : public Pass {
public: public:
/// \brief Run the transformation pass against the IR tree.
/// \param[inout] root_ir Pointer to the IR tree to be transformed.
/// \param[inout] modified Indicate if the tree was modified
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final;

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
/// \brief Run the transformation pass against the execution tree. /// \brief Run the transformation pass against the execution tree.
/// \param[inout] tree Pointer to the execution tree to be transformed. /// \param[inout] tree Pointer to the execution tree to be transformed.
/// \param[inout] modified Indicate if the tree was modified /// \param[inout] modified Indicate if the tree was modified
@@ -121,6 +188,7 @@ class TreePass : public Pass {
/// \param[inout] Indicate of the tree was modified. /// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return /// \return Status The error code return
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
//////////////////////////////////
}; };


// NodePass is a basic Pass class which performs transformation on Node visiting. // NodePass is a basic Pass class which performs transformation on Node visiting.
@@ -136,6 +204,175 @@ class NodePass : public Pass {


~NodePass() = default; ~NodePass() = default;


/// \brief Run the transformation pass against the IR tree
/// \param[inout] root_ir Pointer to the IR tree to be transformed
/// \param[inout] modified Indicator if the tree was changed
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final;

/// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down
/// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all
/// \return Status The error code return
virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }

/// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation
/// "modified" flag needs to be set to true if node is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all.
/// \return Status The error code return
virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }

// For datasetops IR
// Visit method to be overridden.
// Note that member template can not be virtual, any node which wants to work with NodePass
// should declare Visit of its own type and override "Accept" from DatasetNode.
virtual Status Visit(std::shared_ptr<BatchNode> node, bool *modified);

// VisitAfter method to be overridden.
// Note that member template can not be virtual, any node which wants to work with NodePass
// should declare VisitAfter of its own type and override "AcceptAfter" from DatasetNode.
virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<MapNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RenameNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<SkipNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *modified);

#ifdef ENABLE_PYTHON
virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified);

// For datasetops/source IR
virtual Status Visit(std::shared_ptr<AlbumNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<CelebANode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CelebANode> node, bool *modified);

virtual Status Visit(std::shared_ptr<Cifar100Node> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified);

virtual Status Visit(std::shared_ptr<Cifar10Node> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<CLUENode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CLUENode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<CocoNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CocoNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<CSVNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CSVNode> node, bool *modified);
#endif

#ifdef ENABLE_PYTHON
virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<ImageFolderNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ManifestNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<MnistNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<MnistNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RandomNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<TextFileNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified);
#endif

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<VOCNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<VOCNode> node, bool *modified);

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
/// \brief Run the transformation pass against the execution tree /// \brief Run the transformation pass against the execution tree
/// \param[inout] tree Pointer to the execution tree to be transformed /// \param[inout] tree Pointer to the execution tree to be transformed
/// \param[inout] modified Indicator if the tree was changed /// \param[inout] modified Indicator if the tree was changed
@@ -241,13 +478,23 @@ class NodePass : public Pass {


virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
#endif #endif
//////////////////////////////////


private: private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);

// Helper function to perform BFS visit
Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Helper function to perform DFS visit // Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified); Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);


// Helper function to perform BFS visit // Helper function to perform BFS visit
Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified); Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);
//////////////////////////////////


// Tree traversal order of the NodePass // Tree traversal order of the NodePass
Order traversalOrder_; Order traversalOrder_;


Loading…
Cancel
Save