!3294 Refactor opt/pre passes

Merge pull request !3294 from nsyca/removal_pass
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent commit df1300d9cb
13 changed files with 407 additions and 532 deletions
  1. +2 -2     mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
  2. +1 -3     mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt
  3. +0 -181   mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc
  4. +0 -141   mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h
  5. +164 -10  mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc
  6. +117 -8   mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h
  7. +26 -35   mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc
  8. +20 -11   mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h
  9. +0 -58    mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc
  10. +0 -64   mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h
  11. +33 -7   mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc
  12. +39 -7   mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h
  13. +5 -5    tests/ut/python/dataset/test_minddataset_exception.py

+2 -2  mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc

@@ -23,7 +23,7 @@
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/monitor.h"
@@ -225,7 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() {
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
MS_LOG(INFO) << "Running pre pass loops.";
pre_actions.push_back(std::make_unique<InjectionPass>());
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>());
pre_actions.push_back(std::make_unique<CacheTransformPass>());
// Apply pre action passes
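
Note: the hunk above swaps InjectionPass for EpochInjectionPass in the pre-pass list; each pre pass is then applied to the execution tree in order. The following is a minimal, self-contained sketch of that driver pattern only; Tree, Pass and the two *Sketch passes are illustrative stand-ins, not the real MindData ExecutionTree/Pass API, which is richer than shown here.

#include <iostream>
#include <memory>
#include <vector>

struct Tree {};  // stand-in for the real ExecutionTree

struct Pass {
  virtual ~Pass() = default;
  virtual bool Run(Tree *tree, bool *modified) = 0;  // returns true on success
};

struct EpochInjectionPassSketch : Pass {
  bool Run(Tree *, bool *modified) override { *modified = false; return true; }
};

struct RemovalPassSketch : Pass {
  bool Run(Tree *, bool *modified) override { *modified = false; return true; }
};

int main() {
  Tree tree;
  std::vector<std::unique_ptr<Pass>> pre_actions;
  pre_actions.push_back(std::make_unique<EpochInjectionPassSketch>());
  pre_actions.push_back(std::make_unique<RemovalPassSketch>());
  bool any_modified = false;
  for (auto &pass : pre_actions) {
    bool modified = false;
    if (!pass->Run(&tree, &modified)) return 1;  // stop at the first failing pass
    any_modified = any_modified || modified;
  }
  std::cout << "tree modified by pre passes: " << std::boolalpha << any_modified << "\n";
  return 0;
}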


+1 -3  mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt

@@ -3,10 +3,8 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_library(engine-opt OBJECT
pass.cc
post/repeat_pass.cc
pre/cache_pass.cc
pre/cache_transform_pass.cc
pre/injection_pass.cc
pre/removal_nodes.cc
pre/epoch_injection_pass.cc
pre/removal_pass.cc
optional/tensor_op_fusion_pass.cc
util/printer_pass.cc


+0 -181  mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc

@@ -1,181 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include "minddata/dataset/engine/opt/pre/cache_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"

namespace mindspore {
namespace dataset {

// Constructor
CachePass::CachePass(CacheTransformPass *transform_pass)
: transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {}

// Identifies the subtree below this node as a cached descendant tree.
Status CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!");
}
is_caching_ = true;
return Status::OK();
}

// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
// transformation
Status CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
is_caching_ = false; // We a no longer in a cache subtree. clear the flag.
if (leaf_op_) {
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
// using base class pointers.
transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node);
} else {
// If there was no leaf_op set, then this is a non-mappable scenario.

if (sampler_) {
// Grab the sampler that was saved from the leaf and plug it into the cache op
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
} else {
// We're a cache op but no sampler was saved from leaf, so create a default sampler
const int64_t num_samples = 0;
const int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index);
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
}

// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
}

return Status::OK();
}

// Common code for mappable leaf setup.
Status CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}

// If we are a leaf in the caching path, then save this leaf.
if (is_caching_) {
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
leaf_op_ = std::move(leaf_op);
}
return Status::OK();
}

// Common code for non mappable leaf setup.
Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}

// Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf
// as save it for use by cache op in ascendant tree.
if (is_caching_) {
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
} else {
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
// remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
std::shared_ptr<Sampler> sampler_from_leaf;
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
}
return Status::OK();
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
if (is_caching_) {
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
} // namespace dataset
} // namespace mindspore

+0 -141  mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h

@@ -1,141 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_

#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {

class CacheTransformPass;

/// \class CachePass cache_pass.h
/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache
/// transformation. It works in conjunction with the CacheTransformPass
class CachePass : public NodePass {
public:
/// \brief Constructor
/// \param[in] transform_pass Raw pointer back to controlling tree pass
explicit CachePass(CacheTransformPass *transform_pass);

/// \brief Destructor
~CachePass() = default;

/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
/// transformation
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;

/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;

private:
/// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);

/// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);

bool is_caching_;
std::shared_ptr<DatasetOp> leaf_op_;
std::shared_ptr<Sampler> sampler_;
CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass
};

} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_

+164 -10  mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc

@@ -15,17 +15,177 @@
*/

#include <vector>
#include "minddata/dataset/engine/opt/pre/cache_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"

namespace mindspore {
namespace dataset {

// Constructor
CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {}

// Identifies the subtree below this node as a cached descendant tree.
Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!");
}
is_caching_ = true;
return Status::OK();
}

// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
// transformation
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
is_caching_ = false;  // We are no longer in a cache subtree, so clear the flag.
if (leaf_op_) {
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
// using base class pointers.
AddMappableCacheOperators(std::move(leaf_op_), node);
} else {
// If there was no leaf_op set, then this is a non-mappable scenario.

if (sampler_) {
// Grab the sampler that was saved from the leaf and plug it into the cache op
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
} else {
// We're a cache op but no sampler was saved from leaf, so create a default sampler
int64_t num_samples = 0;
int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index);
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
}

// Get the computed checksum from all ops in our cache path below us and ask the cache op to create its cache
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
}

return Status::OK();
}

// Common code for mappable leaf setup.
Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}

// If we are a leaf in the caching path, then save this leaf.
if (is_caching_) {
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
leaf_op_ = std::move(leaf_op);
}
return Status::OK();
}

// Common code for non-mappable leaf setup.
Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}

// The sampler for a non-mappable dataset only works if there is a downstream cache. Remove it from the leaf
// and save it for use by the cache op in the ascendant tree.
if (is_caching_) {
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
} else {
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
// remove it here. The leaf itself will provide its own method of fetching the data (not sampler-based).
std::shared_ptr<Sampler> sampler_from_leaf;
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
}
return Status::OK();
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
if (is_caching_) {
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Assigns the leaf and cache operators that are involved in a cache transformation
void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<CacheOp> cache_op) {
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
}

// constructor
CacheTransformPass::CacheTransformPass() {}

@@ -34,11 +194,11 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
// Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
// use to execute a transform.
std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this);
RETURN_IF_NOT_OK(cache_pass->Run(tree, modified));
CachePass cache_pass = CachePass();
RETURN_IF_NOT_OK(cache_pass.Run(tree, modified));

// Then, execute the transform for each pair
for (auto cache_pair : cache_pairs_) {
for (auto cache_pair : cache_pass.cache_pairs()) {
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client());
}
@@ -98,11 +258,5 @@ Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::share

return Status::OK();
}

// Assigns the leaf and cache operators that are involved in a cache transformation
void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<CacheOp> cache_op) {
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
}
} // namespace dataset
} // namespace mindspore
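
Note: the main change in this file is structural. CachePass now lives inside CacheTransformPass and hands its results back through the cache_pairs() getter instead of writing into the owning pass through a raw transform_pass_ pointer. A self-contained analog of that collect-then-consume split is sketched below; Node, TransformPass and Collector are illustrative stand-ins, not the MindData API, and leaf/cache detection is reduced to a simple name check.

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
};

class TransformPass {
 public:
  void Run(const std::shared_ptr<Node> &root) {
    Collector collector;
    collector.Walk(root, /*under_cache=*/false);
    // Consume the collected (leaf, cache) pairs through the getter.
    for (const auto &pair : collector.pairs()) {
      std::cout << "would transform leaf '" << pair.first->name
                << "' under cache '" << pair.second->name << "'\n";
    }
  }

 private:
  // Nested collector, analogous to CacheTransformPass::CachePass.
  class Collector {
   public:
    void Walk(const std::shared_ptr<Node> &node, bool under_cache) {
      bool is_cache = (node->name == "cache");
      for (const auto &child : node->children) Walk(child, under_cache || is_cache);
      if (node->children.empty() && under_cache) leaf_ = node;   // remember the leaf in the cached subtree
      if (is_cache && leaf_) pairs_.emplace_back(leaf_, node);   // pair it with the cache on the way back up
    }
    const std::vector<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>> &pairs() const {
      return pairs_;
    }

   private:
    std::shared_ptr<Node> leaf_;
    std::vector<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>> pairs_;
  };
};

int main() {
  auto leaf = std::make_shared<Node>(Node{"image_folder", {}});
  auto cache = std::make_shared<Node>(Node{"cache", {leaf}});
  auto root = std::make_shared<Node>(Node{"map", {cache}});
  TransformPass().Run(root);
  return 0;
}

Returning the collected pairs through a getter keeps the inner visitor free of any back-reference to its owner, which is what removes the old error path for a missing transform_pass_ pointer.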

+117 -8  mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h

@@ -33,6 +33,123 @@ class CacheClient;
/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
/// operations
class CacheTransformPass : public TreePass {
/// \class CachePass
/// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache
/// transformation. It works in conjunction with the CacheTransformPass
class CachePass : public NodePass {
public:
/// \brief Constructor
CachePass();

/// \brief Destructor
~CachePass() = default;

/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Resets the tracking of the cache within the tree and assigns the operators that
/// will be involved in a cache transformation
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;

/// \brief Getter
/// \return The leaf/cache operator pairs collected by this pass
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs() { return cache_pairs_; }

private:
/// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);

/// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);

/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
/// \param[in] leaf_op The leaf operator involved in the cache transform
/// \param[in] cache_op The cache operator involved in the cache transform
void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);

bool is_caching_;
std::shared_ptr<DatasetOp> leaf_op_;
std::shared_ptr<Sampler> sampler_;
// The (leaf op, cache op) pairs that work together to establish the cache transform
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
};

public:
/// \brief Constructor
CacheTransformPass();
@@ -46,11 +163,6 @@ class CacheTransformPass : public TreePass {
/// \return Status The error code return
Status RunOnTree(ExecutionTree *tree, bool *modified) override;

/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
/// \param[in] leaf_op The leaf operator involved in the cache transform
/// \param[in] cache_op The cache operator involved in the cache transform
void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);

private:
/// \brief Helper function to execute the cache transformation.
///
@@ -72,9 +184,6 @@ class CacheTransformPass : public TreePass {
/// \return Status The error code return
Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);

// The two operators that work together to establish the cache transform
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
};
} // namespace dataset
} // namespace mindspore


mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.cc → mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc

@@ -16,7 +16,7 @@

#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
@@ -25,64 +25,55 @@ namespace mindspore {
namespace dataset {

// constructor
InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {}
EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetOp> node) : injection_point_(node) {}

// Performs finder work for BuildVocabOp that has special rules about epoch control injection
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
if (injection_pass_) {
injection_pass_->epoch_ctrl_bypass_ = true;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
}
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
injection_point_ = nullptr;
return Status::OK();
}

// Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) {
if (injection_pass_) {
injection_pass_->epoch_ctrl_bypass_ = true;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
}
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node,
bool *modified) {
injection_point_ = nullptr;
return Status::OK();
}

// Temporary code to prevent the injection of epoch control when cache op is present
// Remove this code in cache op phase 2
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
if (injection_pass_) {
injection_pass_->epoch_ctrl_bypass_ = true;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
}
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
injection_point_ = nullptr;
return Status::OK();
}

Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
// Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here.
injection_point_ = node->child(0);
return Status::OK();
}

// constructor
InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {}
EpochInjectionPass::EpochInjectionPass() {}

// Runs an injection pass to inject in operators needed at the pre pass stage
Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: Injection pass started.";

// First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
// The finder can make updates to the InjectionPass object.
InjectionPass::InjectionFinder finder(this);
finder.Run(tree, modified);
// The finder locates the point at which the epoch control operator should be injected.
EpochInjectionPass::InjectionFinder finder(tree->root());
RETURN_IF_NOT_OK(finder.Run(tree, modified));

// The first injection logic is to check if we should inject the epoch control op as the root node.
// Do not inject the op if the number of epochs is 1.
int32_t num_epochs = tree->num_epochs();
if (num_epochs != 1 && !epoch_ctrl_bypass_) {
std::shared_ptr<DatasetOp> epoch_inject_node = finder.injection_point();
if (num_epochs != 1 && epoch_inject_node != nullptr) {
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op));
RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op));
std::shared_ptr<DatasetOp> node = tree->root();
if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) {
tree->root()->InsertAsParent(epoch_ctrl_op);
} else {
tree->root()->child(0)->InsertAsParent(epoch_ctrl_op);
}
epoch_inject_node->InsertAsParent(epoch_ctrl_op);
}

MS_LOG(INFO) << "Pre pass: Injection pass complete.";

mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.h → mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_

#include <memory>
#include <vector>
@@ -26,10 +26,10 @@ namespace dataset {

class DatasetOp;

/// \class InjectionPass injection_pass.h
/// \class EpochInjectionPass epoch_injection_pass.h
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
/// parsing.
class InjectionPass : public TreePass {
class EpochInjectionPass : public TreePass {
/// \class InjectionFinder
/// \brief This is a nested node pass class whose job is to parse the tree and perform any identification logic for
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
@@ -37,7 +37,10 @@ class InjectionPass : public TreePass {
class InjectionFinder : public NodePass {
public:
/// \brief Constructor
explicit InjectionFinder(InjectionPass *injection_pass);
explicit InjectionFinder(std::shared_ptr<DatasetOp> node);

/// \brief Destructor
~InjectionFinder() = default;

/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited
@@ -58,24 +61,30 @@ class InjectionPass : public TreePass {
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Register the DeviceQueueOp for further action.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;

/// \brief Getter
/// \return The op below which the epoch control operator should be injected, or nullptr if injection is bypassed
std::shared_ptr<DatasetOp> injection_point() { return injection_point_; }

private:
InjectionPass *injection_pass_;
std::shared_ptr<DatasetOp> injection_point_;
};

public:
/// \brief Constructor
InjectionPass();
EpochInjectionPass();

/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
/// \param[inout] tree The tree to operate on.
/// \param[inout] modified Indicator of whether the tree was modified.
/// \return Status The error code return
Status RunOnTree(ExecutionTree *tree, bool *modified) override;

private:
bool epoch_ctrl_bypass_;
};
} // namespace dataset
} // namespace mindspore

#endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_

+0 -58  mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc

@@ -1,58 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"

namespace mindspore {
namespace dataset {

RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}

// Identifies the subtree below this node as a cached descendant tree.
Status RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
return Status::OK();
}

// Resets the tracking of the cache within the tree
Status RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
is_caching_ = false;
return Status::OK();
}

// Perform ShuffleOp removal check.
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
*modified = false;
// If we are in a cache descendant tree, then this shuffle op needs to be removed
if (is_caching_) {
MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
if (removal_pass_) {
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!");
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+0 -64  mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h

@@ -1,64 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_

#include <memory>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"

namespace mindspore {
namespace dataset {
/// \class RemovalNodes removal_nodes.h
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
/// It works in conjunction with the removal_pass.
class RemovalNodes : public NodePass {
public:
/// \brief Constructor
/// \param[in] removal_pass Raw pointer back to controlling tree pass
explicit RemovalNodes(RemovalPass *removal_pass);

/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Destructor
~RemovalNodes() = default;

/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;

private:
bool is_caching_;
RemovalPass *removal_pass_; // Back pointer to the owning removal pass
};

} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_

+33 -7  mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc

@@ -16,32 +16,58 @@

#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/execution_tree.h"

namespace mindspore {
namespace dataset {

RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {}

// Identifies the subtree below this node as a cached descendant tree.
Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
return Status::OK();
}

// Resets the tracking of the cache within the tree
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
is_caching_ = false;
return Status::OK();
}

// Perform ShuffleOp removal check.
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
*modified = false;
// If we are in a cache descendant tree, then this shuffle op needs to be removed
if (is_caching_) {
MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetOp>(node));
}
return Status::OK();
}

// constructor
RemovalPass::RemovalPass() {}

// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
// Walks the tree to collect the nodes to remove, then removes them.
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: removal pass started.";
// Create the removal node pass which can identify which nodes need to be removed.
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
std::unique_ptr<RemovalPass::RemovalNodes> removal_nodes = std::make_unique<RemovalPass::RemovalNodes>();
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));

// Then, execute the removal of any nodes that were set up for removal
for (auto node : removal_nodes_) {
for (auto node : removal_nodes->nodes_to_remove()) {
node->Remove();
}
MS_LOG(INFO) << "Pre pass: removal pass complete.";
return Status::OK();
}

// Adds an operator to the list of operators to be removed
void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); }
} // namespace dataset
} // namespace mindspore
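
Note: RemovalNodes brackets a cache subtree with its CacheOp callbacks: PreRunOnNode sets is_caching_ on the way down, RunOnNode clears it on the way back up, and any ShuffleOp visited in between is queued for removal; RemovalPass then removes the queued nodes. A self-contained sketch of that bracketing idea follows; Node and ShuffleCollector are illustrative stand-ins, not the MindData NodePass machinery.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
};

class ShuffleCollector {
 public:
  void Walk(const std::shared_ptr<Node> &node) {
    if (node->name == "cache") is_caching_ = true;   // analog of PreRunOnNode(CacheOp)
    if (node->name == "shuffle" && is_caching_) to_remove_.push_back(node);
    for (const auto &child : node->children) Walk(child);
    if (node->name == "cache") is_caching_ = false;  // analog of RunOnNode(CacheOp)
  }
  const std::vector<std::shared_ptr<Node>> &to_remove() const { return to_remove_; }

 private:
  bool is_caching_ = false;
  std::vector<std::shared_ptr<Node>> to_remove_;
};

int main() {
  // cache -> shuffle: the shuffle sits in the cache's descendant tree, so it is collected.
  auto shuffle = std::make_shared<Node>(Node{"shuffle", {}});
  auto cache = std::make_shared<Node>(Node{"cache", {shuffle}});
  ShuffleCollector collector;
  collector.Walk(cache);
  std::cout << collector.to_remove().size() << " node(s) marked for removal\n";
  return 0;
}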

+39 -7  mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h

@@ -30,6 +30,45 @@ class DatasetOp;
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
/// nodes should be removed, and then removes them.
class RemovalPass : public TreePass {
/// \class RemovalNodes
/// \brief This is a NodePass whose job is to identify which nodes should be removed.
/// It works in conjunction with the removal_pass.
class RemovalNodes : public NodePass {
public:
/// \brief Constructor
RemovalNodes();

/// \brief Destructor
~RemovalNodes() = default;

/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;

/// \brief Getter
/// \return All the nodes to be removed
std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove() { return nodes_to_remove_; }

private:
bool is_caching_;
std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove_;
};

public:
/// \brief Constructor
RemovalPass();
@@ -42,13 +81,6 @@ class RemovalPass : public TreePass {
/// \param[inout] modified Indicator of whether the tree was modified.
/// \return Status The error code return
Status RunOnTree(ExecutionTree *tree, bool *modified) override;

/// \brief Adds an operator to the list of operators to be removed
/// \param[in] dataset_op The operator to add to the removal list
void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op);

private:
std::vector<std::shared_ptr<DatasetOp>> removal_nodes_;
};
} // namespace dataset
} // namespace mindspore


+5 -5  tests/ut/python/dataset/test_minddataset_exception.py

@@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards():
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info)
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)

os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
@@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id():
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info)
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))

@@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard():
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info)
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value)

with pytest.raises(Exception) as error_info:
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info)
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value)

os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
@@ -245,7 +245,7 @@ def test_cv_minddataset_partition_num_samples_equals_0():
num_iter += 1
with pytest.raises(Exception) as error_info:
partitions(5)
assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info)
assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value)

os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
