Browse Source

Add optimizer to IR tree #1

tags/v1.1.0
Nat Sutyanyong 5 years ago
parent
commit
80d02d6dcd
4 changed files with 712 additions and 0 deletions
  1. +14
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  2. +19
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  3. +432
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
  4. +247
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.h

+ 14
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -20,6 +20,7 @@
#include <memory>
#include <set>

#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/random.h"

namespace mindspore {
@@ -227,6 +228,7 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
num_workers_ = num_workers;
return shared_from_this();
}

DatasetNode::DatasetNode() {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
@@ -236,5 +238,17 @@ DatasetNode::DatasetNode() {
worker_connector_size_ = cfg->worker_connector_size();
}

// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
Status DatasetNode::Accept(NodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one.
return p->Visit(shared_from_this(), modified);
}

// In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
// after all child nodes are visited.
Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one.
return p->VisitAfter(shared_from_this(), modified);
}
} // namespace dataset
} // namespace mindspore

+ 19
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -31,6 +31,7 @@ namespace dataset {

class Dataset;
class SamplerObj;
class NodePass;

#define RETURN_EMPTY_IF_ERROR(_s) \
do { \
@@ -107,6 +108,24 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Shared pointer to the original object
std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);

/// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
/// visit on the way back up the tree after its descendants are visited.
/// \notes Subclass needs to override this if it requires special node visit access.
/// Check "dataset/engine/opt/pass.h" for more details.
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
virtual Status Accept(NodePass *p, bool *modified);

/// \brief Base method for NodePass visit on the way back up the tree after its descendants are visited.
/// \notes Subclass needs to override this if it requires special node visit access.
/// Check "dataset/engine/opt/pass.h" for more details.
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
virtual Status AcceptAfter(NodePass *p, bool *modified);

protected:
std::vector<std::shared_ptr<DatasetNode>> children;
std::shared_ptr<DatasetNode> parent;


+ 432
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc View File

@@ -15,6 +15,56 @@
*/

#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#endif
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#endif
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#ifndef ENABLE_ANDROID
@@ -57,10 +107,391 @@
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
//////////////////////////////////

namespace mindspore {
namespace dataset {

// Driver method for TreePass
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }

// Driver method for NodePass
Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
if (root_ir == nullptr || modified == nullptr) {
return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
}
if (traversalOrder_ == Order::DFS) {
// DFS
return DFSNodeVisit(root_ir, modified);
} else if (traversalOrder_ == Order::BFS) {
// BFS
return BFSNodeVisit(root_ir, modified);
}
return Status::OK();
}

// Helper function to perform DFS visit
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
RETURN_IF_NOT_OK(node_ir->Accept(this, modified));
for (const auto &c : node_ir->Children()) {
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
}
return node_ir->AcceptAfter(this, modified);
}

// Helper function to perform BFS visit
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
// Initialize bfs queue with root
std::queue<std::shared_ptr<DatasetNode>> bfsQueue;
bfsQueue.push(node_ir);

// BFS loop
while (!bfsQueue.empty()) {
// Pop the front of the bfs queue
auto curNode = bfsQueue.front();
bfsQueue.pop();

// Run node pass
RETURN_IF_NOT_OK(curNode->Accept(this, modified));

// Push children into bfs queue
for (const auto &c : curNode->Children()) {
bfsQueue.push(c);
}
}
return Status::OK();
}

// For datasetops IR
Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifdef ENABLE_PYTHON
Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

// For datasetops/source IR
Status NodePass::Visit(std::shared_ptr<AlbumNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<CelebANode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CelebANode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<Cifar100Node> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<Cifar10Node> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<CLUENode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CLUENode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<CocoNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CocoNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<CSVNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CSVNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

#ifdef ENABLE_PYTHON
Status NodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<ImageFolderNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ManifestNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<MindDataNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<MnistNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<MnistNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<RandomNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<TextFileNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<VOCNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<VOCNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Driver method for TreePass
Status TreePass::Run(ExecutionTree *tree, bool *modified) {
if (tree == nullptr || modified == nullptr) {
@@ -320,5 +751,6 @@ Status NodePass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#endif
//////////////////////////////////
} // namespace dataset
} // namespace mindspore

+ 247
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.h View File

@@ -21,10 +21,61 @@
#include <queue>

#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
class BatchNode;
class BucketBatchByLengthNode;
#ifndef ENABLE_ANDROID
class BuildSentenceVocabNode;
#endif
class BuildVocabNode;
class ConcatNode;
class MapNode;
class ProjectNode;
class RenameNode;
class RepeatNode;
class ShuffleNode;
class SkipNode;
#ifdef ENABLE_PYTHON
class SyncWaitNode;
#endif
class TakeNode;
class TransferNode;
class ZipNode;
class AlbumNode;
class CelebANode;
class Cifar100Node;
class Cifar10Node;
#ifndef ENABLE_ANDROID
class CLUENode;
#endif
class CocoNode;
#ifndef ENABLE_ANDROID
class CSVNode;
#endif
#ifdef ENABLE_PYTHON
class GeneratorNode;
#endif
class ImageFolderNode;
class ManifestNode;
#ifndef ENABLE_ANDROID
class MindDataNode;
#endif
class MnistNode;
class RandomNode;
#ifndef ENABLE_ANDROID
class TextFileNode;
#endif
#ifndef ENABLE_ANDROID
class TFRecordNode;
#endif
class VOCNode;

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
class BatchOp;

class MapOp;
@@ -94,15 +145,24 @@ class FilterOp;

class GeneratorOp;
#endif
//////////////////////////////////

// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
public:
// Run the transformation pass against the IR tree.
// @param root_ir - Pointer to the IR tree to be transformed.
// @param modified - Pointer to the modified flag,
virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) = 0;

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
virtual Status Run(ExecutionTree *tree, bool *modified) = 0;
//////////////////////////////////

virtual ~Pass() = default;
};
@@ -110,6 +170,13 @@ class Pass : public std::enable_shared_from_this<Pass> {
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
class TreePass : public Pass {
public:
/// \brief Run the transformation pass against the IR tree.
/// \param[inout] root_ir Pointer to the IR tree to be transformed.
/// \param[inout] modified Indicate if the tree was modified
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final;

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
/// \brief Run the transformation pass against the execution tree.
/// \param[inout] tree Pointer to the execution tree to be transformed.
/// \param[inout] modified Indicate if the tree was modified
@@ -121,6 +188,7 @@ class TreePass : public Pass {
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
//////////////////////////////////
};

// NodePass is a basic Pass class which performs transformation on Node visiting.
@@ -136,6 +204,175 @@ class NodePass : public Pass {

~NodePass() = default;

/// \brief Run the transformation pass against the IR tree
/// \param[inout] root_ir Pointer to the IR tree to be transformed
/// \param[inout] modified Indicator if the tree was changed
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final;

/// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down
/// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all
/// \return Status The error code return
virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }

/// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation
/// "modified" flag needs to be set to true if node is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all.
/// \return Status The error code return
virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }

// For datasetops IR
// Visit method to be overridden.
// Note that member template can not be virtual, any node which wants to work with NodePass
// should declare Visit of its own type and override "Accept" from DatasetNode.
virtual Status Visit(std::shared_ptr<BatchNode> node, bool *modified);

// VisitAfter method to be overridden.
// Note that member template can not be virtual, any node which wants to work with NodePass
// should declare VisitAfter of its own type and override "AcceptAfter" from DatasetNode.
virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<MapNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RenameNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<SkipNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *modified);

#ifdef ENABLE_PYTHON
virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified);

// For datasetops/source IR
virtual Status Visit(std::shared_ptr<AlbumNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<CelebANode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CelebANode> node, bool *modified);

virtual Status Visit(std::shared_ptr<Cifar100Node> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified);

virtual Status Visit(std::shared_ptr<Cifar10Node> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<CLUENode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CLUENode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<CocoNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CocoNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<CSVNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CSVNode> node, bool *modified);
#endif

#ifdef ENABLE_PYTHON
virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<ImageFolderNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ManifestNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<MnistNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<MnistNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RandomNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<TextFileNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified);
#endif

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<VOCNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<VOCNode> node, bool *modified);

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
/// \brief Run the transformation pass against the execution tree
/// \param[inout] tree Pointer to the execution tree to be transformed
/// \param[inout] modified Indicator if the tree was changed
@@ -241,13 +478,23 @@ class NodePass : public Pass {

virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
#endif
//////////////////////////////////

private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);

// Helper function to perform BFS visit
Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);

// Helper function to perform BFS visit
Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);
//////////////////////////////////

// Tree traversal order of the NodePass
Order traversalOrder_;


Loading…
Cancel
Save