Merge pull request !3294 from nsyca/removal_passtags/v0.7.0-beta
| @@ -23,7 +23,7 @@ | |||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "minddata/dataset/engine/opt/post/repeat_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/injection_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" | |||
| #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | |||
| #include "minddata/dataset/engine/perf/profiling.h" | |||
| #include "minddata/dataset/engine/perf/monitor.h" | |||
| @@ -225,7 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() { | |||
| std::vector<std::unique_ptr<Pass>> pre_actions; | |||
| // Construct pre actions | |||
| MS_LOG(INFO) << "Running pre pass loops."; | |||
| pre_actions.push_back(std::make_unique<InjectionPass>()); | |||
| pre_actions.push_back(std::make_unique<EpochInjectionPass>()); | |||
| pre_actions.push_back(std::make_unique<RemovalPass>()); | |||
| pre_actions.push_back(std::make_unique<CacheTransformPass>()); | |||
| // Apply pre action passes | |||
| @@ -3,10 +3,8 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE | |||
| add_library(engine-opt OBJECT | |||
| pass.cc | |||
| post/repeat_pass.cc | |||
| pre/cache_pass.cc | |||
| pre/cache_transform_pass.cc | |||
| pre/injection_pass.cc | |||
| pre/removal_nodes.cc | |||
| pre/epoch_injection_pass.cc | |||
| pre/removal_pass.cc | |||
| optional/tensor_op_fusion_pass.cc | |||
| util/printer_pass.cc | |||
| @@ -1,181 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/opt/pre/cache_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/generator_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/mnist_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor: keeps the back pointer to the owning transform pass; starts outside any cache subtree with no leaf saved. | |||
| CachePass::CachePass(CacheTransformPass *transform_pass) | |||
| : transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {} | |||
| // Marks the subtree below this CacheOp as a cached descendant tree. A CacheOp seen while already inside one is an error. | |||
| Status CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); | |||
| } | |||
| is_caching_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Leaves the cache subtree (clears is_caching_). With a saved mappable leaf, the leaf/cache pair is handed to the | |||
| // transform pass; otherwise the cache op gets a sampler (saved or default sequential) and creates its cache. | |||
| Status CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| is_caching_ = false; // We are no longer in a cache subtree; clear the flag. | |||
| if (leaf_op_) { | |||
| MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; | |||
| // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, | |||
| // using base class pointers. | |||
| transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node); | |||
| } else { | |||
| // If there was no leaf_op set, then this is a non-mappable scenario. | |||
| if (sampler_) { | |||
| // Grab the sampler that was saved from the leaf and plug it into the cache op | |||
| node->SetSampler(std::move(sampler_)); | |||
| MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; | |||
| } else { | |||
| // We're a cache op but no sampler was saved from leaf, so create a default sampler | |||
| const int64_t num_samples = 0; | |||
| const int64_t start_index = 0; | |||
| sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| node->SetSampler(std::move(sampler_)); | |||
| MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; | |||
| } | |||
| // Get the computed check sum from all ops in our cache path below us and ask the cache op to create its cache | |||
| uint32_t cache_crc = DatasetOp::GenerateCRC(node); | |||
| RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Common code for mappable leaf setup: inside a cache subtree the leaf is saved in leaf_op_; a second leaf is an error. | |||
| Status CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // If we are a leaf in the caching path, then save this leaf. | |||
| if (is_caching_) { | |||
| MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; | |||
| leaf_op_ = std::move(leaf_op); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Common code for non-mappable leaf setup: the leaf's sampler is always removed, and kept in sampler_ only when caching. | |||
| Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // Sampler for a non-mappable dataset only works if there is a downstream cache. Remove it from the leaf | |||
| // and save it for use by the cache op in the ascendant tree. | |||
| if (is_caching_) { | |||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); | |||
| MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; | |||
| } else { | |||
| // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can | |||
| // remove it here. The leaf itself will provide its own methods of fetching the data (not sampler-based) | |||
| std::shared_ptr<Sampler> sampler_from_leaf; | |||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Leaf node cache transform identification: TFReaderOp is handled as a non-mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||
| if (is_caching_) { | |||
| // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic | |||
| // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. | |||
| node->MakeSimpleProducer(); | |||
| } | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: RandomDataOp is handled as a non-mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: ImageFolderOp is handled as a mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: MnistOp is handled as a mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: GeneratorOp is handled as a mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: ManifestOp is handled as a mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: CifarOp is handled as a mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: VOCOp is handled as a mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: CocoOp is handled as a mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: CelebAOp is handled as a mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: MindRecordOp is handled as a mappable leaf. | |||
| Status CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,141 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class CacheTransformPass; | |||
| /// \class CachePass cache_pass.h | |||
| /// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache | |||
| /// transformation. It works in conjunction with the CacheTransformPass | |||
| class CachePass : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] transform_pass Raw pointer back to controlling tree pass | |||
| explicit CachePass(CacheTransformPass *transform_pass); | |||
| /// \brief Destructor | |||
| ~CachePass() = default; | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache | |||
| /// transformation | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||
| private: | |||
| /// \brief Common code for mappable leaf setup. | |||
| /// \param[in] leaf_op The leaf node performing setup work. | |||
| /// \return Status The error code return | |||
| Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| /// \brief Common code for non-mappable leaf setup. | |||
| /// \param[in] leaf_op The leaf node performing setup work. | |||
| /// \return Status The error code return | |||
| Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| bool is_caching_; | |||
| std::shared_ptr<DatasetOp> leaf_op_; | |||
| std::shared_ptr<Sampler> sampler_; | |||
| CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ | |||
| @@ -15,17 +15,177 @@ | |||
| */ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pre/cache_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/generator_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/mnist_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor: starts outside any cache subtree with no leaf saved. | |||
| CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {} | |||
| // Marks the subtree below this CacheOp as a cached descendant tree. A CacheOp seen while already inside one is an error. | |||
| Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); | |||
| } | |||
| is_caching_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Leaves the cache subtree (clears is_caching_). With a saved mappable leaf, the leaf/cache pair is recorded for the | |||
| // transformation; otherwise the cache op gets a sampler (saved or default sequential) and creates its cache. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| is_caching_ = false; // We are no longer in a cache subtree; clear the flag. | |||
| if (leaf_op_) { | |||
| MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; | |||
| // Record the leaf op for the transformation, using move to null our copy of it, together with the cache op, | |||
| // using base class pointers. | |||
| AddMappableCacheOperators(std::move(leaf_op_), node); | |||
| } else { | |||
| // If there was no leaf_op set, then this is a non-mappable scenario. | |||
| if (sampler_) { | |||
| // Grab the sampler that was saved from the leaf and plug it into the cache op | |||
| node->SetSampler(std::move(sampler_)); | |||
| MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; | |||
| } else { | |||
| // We're a cache op but no sampler was saved from leaf, so create a default sampler | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| node->SetSampler(std::move(sampler_)); | |||
| MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; | |||
| } | |||
| // Get the computed check sum from all ops in our cache path below us and ask the cache op to create its cache | |||
| uint32_t cache_crc = DatasetOp::GenerateCRC(node); | |||
| RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Common code for mappable leaf setup: inside a cache subtree the leaf is saved in leaf_op_; a second leaf is an error. | |||
| Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // If we are a leaf in the caching path, then save this leaf. | |||
| if (is_caching_) { | |||
| MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; | |||
| leaf_op_ = std::move(leaf_op); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Common code for non-mappable leaf setup: the leaf's sampler is always removed, and kept in sampler_ only when caching. | |||
| Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // Sampler for a non-mappable dataset only works if there is a downstream cache. Remove it from the leaf | |||
| // and save it for use by the cache op in the ascendant tree. | |||
| if (is_caching_) { | |||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); | |||
| MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; | |||
| } else { | |||
| // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can | |||
| // remove it here. The leaf itself will provide its own methods of fetching the data (not sampler-based) | |||
| std::shared_ptr<Sampler> sampler_from_leaf; | |||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Leaf node cache transform identification: TFReaderOp is handled as a non-mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||
| if (is_caching_) { | |||
| // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic | |||
| // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. | |||
| node->MakeSimpleProducer(); | |||
| } | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: RandomDataOp is handled as a non-mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: ImageFolderOp is handled as a mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: MnistOp is handled as a mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: GeneratorOp is handled as a mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: ManifestOp is handled as a mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: CifarOp is handled as a mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: VOCOp is handled as a mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: CocoOp is handled as a mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: CelebAOp is handled as a mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Leaf node cache transform identification: MindRecordOp is handled as a mappable leaf. | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Appends the (leaf op, cache op) pair involved in a cache transformation to cache_pairs_. | |||
| void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<CacheOp> cache_op) { | |||
| cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); | |||
| } | |||
| // Constructor: empty body, nothing is initialized here. | |||
| CacheTransformPass::CacheTransformPass() {} | |||
| @@ -34,11 +194,11 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
| MS_LOG(INFO) << "Pre pass: Cache transform pass started."; | |||
| // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will | |||
| // use to execute a transform. | |||
| std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this); | |||
| RETURN_IF_NOT_OK(cache_pass->Run(tree, modified)); | |||
| CachePass cache_pass = CachePass(); | |||
| RETURN_IF_NOT_OK(cache_pass.Run(tree, modified)); | |||
| // Then, execute the transform for each pair | |||
| for (auto cache_pair : cache_pairs_) { | |||
| for (auto cache_pair : cache_pass.cache_pairs()) { | |||
| MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; | |||
| ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()); | |||
| } | |||
| @@ -98,11 +258,5 @@ Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::share | |||
| return Status::OK(); | |||
| } | |||
| // Assigns the leaf and cache operators that are involved in a cache transformation | |||
| void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<CacheOp> cache_op) { | |||
| cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -33,6 +33,123 @@ class CacheClient; | |||
| /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching | |||
| /// operations | |||
| class CacheTransformPass : public TreePass { | |||
| /// \class CachePass | |||
| /// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache | |||
| /// transformation. It works in conjunction with the CacheTransformPass | |||
| class CachePass : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \note Takes no arguments; the leaf/cache pairs found by this pass are read back through cache_pairs(). | |||
| CachePass(); | |||
| /// \brief Destructor | |||
| ~CachePass() = default; | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree and assigns the operators that | |||
| /// will be involved in a cache transformation | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications for a VOCOp | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications for a CocoOp | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications for a CelebAOp | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications for a MindRecordOp | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||
| /// \brief Getter for the operator pairs that work together to establish the cache transform | |||
| /// \return A copy of the vector of (DatasetOp, CacheOp) pairs held by this pass | |||
| std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs() { return cache_pairs_; } | |||
| private: | |||
| /// \brief Common code for mappable leaf setup. | |||
| /// \param[in] leaf_op The leaf node performing setup work. | |||
| /// \return Status The error code return | |||
| Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| /// \brief Common code for non-mappable leaf setup. | |||
| /// \param[in] leaf_op The leaf node performing setup work. | |||
| /// \return Status The error code return | |||
| Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| /// \brief Assigns the leaf and cache operators that are involved in a cache transformation | |||
| /// \param[in] leaf_op The leaf operator involved in the cache transform | |||
| /// \param[in] cache_op The cache operator involved in the cache transform | |||
| void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op); | |||
| bool is_caching_; | |||
| std::shared_ptr<DatasetOp> leaf_op_; | |||
| std::shared_ptr<Sampler> sampler_; | |||
| // The two operators that work together to establish the cache transform | |||
| std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_; | |||
| }; | |||
| public: | |||
| /// \brief Constructor | |||
| CacheTransformPass(); | |||
| @@ -46,11 +163,6 @@ class CacheTransformPass : public TreePass { | |||
| /// \return Status The error code return | |||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||
| /// \brief Assigns the leaf and cache operators that are involved in a cache transformation | |||
| /// \param[in] leaf_op The leaf operator involved in the cache transform | |||
| /// \param[in] cache_op The cache operator involved in the cache transform | |||
| void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op); | |||
| private: | |||
| /// \brief Helper function to execute the cache transformation. | |||
| /// | |||
| @@ -72,9 +184,6 @@ class CacheTransformPass : public TreePass { | |||
| /// \return Status The error code return | |||
| Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client); | |||
| // The two operators that work together to establish the cache transform | |||
| std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -16,7 +16,7 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/opt/pre/injection_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| @@ -25,64 +25,55 @@ namespace mindspore { | |||
| namespace dataset { | |||
| // constructor | |||
| InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {} | |||
| EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetOp> node) : injection_point_(node) {} | |||
| // Performs finder work for BuildVocabOp that has special rules about epoch control injection | |||
| Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||
| if (injection_pass_) { | |||
| injection_pass_->epoch_ctrl_bypass_ = true; | |||
| return Status::OK(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); | |||
| } | |||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||
| injection_point_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| // Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection | |||
| Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) { | |||
| if (injection_pass_) { | |||
| injection_pass_->epoch_ctrl_bypass_ = true; | |||
| return Status::OK(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); | |||
| } | |||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, | |||
| bool *modified) { | |||
| injection_point_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| // Temporary code to prevent the injection of epoch control when cache op is present | |||
| // Remove this code in cache op phase 2 | |||
| Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| if (injection_pass_) { | |||
| injection_pass_->epoch_ctrl_bypass_ = true; | |||
| return Status::OK(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); | |||
| } | |||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| injection_point_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| // Records child(0) of the DeviceQueueOp as the injection point, replacing whatever was recorded before | |||
| Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||
| // Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here. | |||
| injection_point_ = node->child(0); | |||
| return Status::OK(); | |||
| } | |||
| // constructor | |||
| InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {} | |||
| EpochInjectionPass::EpochInjectionPass() {} | |||
| // Runs an injection pass to inject in operators needed at the pre pass stage | |||
| Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
| Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
| MS_LOG(INFO) << "Pre pass: Injection pass started."; | |||
| // First, run the finder to perform any injection info before we can go ahead to drive the op injection work. | |||
| // The finder can make updates to the InjectionPass object. | |||
| InjectionPass::InjectionFinder finder(this); | |||
| finder.Run(tree, modified); | |||
| // The finder records the injection point itself; it is read back below through injection_point(). | |||
| EpochInjectionPass::InjectionFinder finder(tree->root()); | |||
| RETURN_IF_NOT_OK(finder.Run(tree, modified)); | |||
| // The first injection logic is to check if we should inject the epoch control op as the root node. | |||
| // Do not inject the op if the number of epochs is 1. | |||
| int32_t num_epochs = tree->num_epochs(); | |||
| if (num_epochs != 1 && !epoch_ctrl_bypass_) { | |||
| std::shared_ptr<DatasetOp> epoch_inject_node = finder.injection_point(); | |||
| if (num_epochs != 1 && epoch_inject_node != nullptr) { | |||
| std::shared_ptr<EpochCtrlOp> epoch_ctrl_op; | |||
| RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op)); | |||
| RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op)); | |||
| std::shared_ptr<DatasetOp> node = tree->root(); | |||
| if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) { | |||
| tree->root()->InsertAsParent(epoch_ctrl_op); | |||
| } else { | |||
| tree->root()->child(0)->InsertAsParent(epoch_ctrl_op); | |||
| } | |||
| epoch_inject_node->InsertAsParent(epoch_ctrl_op); | |||
| } | |||
| MS_LOG(INFO) << "Pre pass: Injection pass complete."; | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| @@ -26,10 +26,10 @@ namespace dataset { | |||
| class DatasetOp; | |||
| /// \class InjectionPass injection_pass.h | |||
| /// \class EpochInjectionPass epoch_injection_pass.h | |||
| /// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api | |||
| /// parsing. | |||
| class InjectionPass : public TreePass { | |||
| class EpochInjectionPass : public TreePass { | |||
| /// \class InjectionFinder | |||
| /// \brief This is a nested node pass class whose job is to parse the tree and perform any identification logic for | |||
| /// operators that need to be injected. It is run first by the main injection pass to find out what operators | |||
| @@ -37,7 +37,10 @@ class InjectionPass : public TreePass { | |||
| class InjectionFinder : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| explicit InjectionFinder(InjectionPass *injection_pass); | |||
| explicit InjectionFinder(std::shared_ptr<DatasetOp> node); | |||
| /// \brief Destructor | |||
| ~InjectionFinder() = default; | |||
| /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| @@ -58,24 +61,30 @@ class InjectionPass : public TreePass { | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Records child(0) of the DeviceQueueOp as the injection point. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | |||
| /// \brief Getter for the injection point identified by this finder | |||
| /// \return The operator that the epoch control op is to be inserted as parent of, or nullptr if no injection | |||
| /// should be done | |||
| std::shared_ptr<DatasetOp> injection_point() { return injection_point_; } | |||
| private: | |||
| InjectionPass *injection_pass_; | |||
| std::shared_ptr<DatasetOp> injection_point_; | |||
| }; | |||
| public: | |||
| /// \brief Constructor | |||
| InjectionPass(); | |||
| EpochInjectionPass(); | |||
| /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] modified Indicator if the tree was modified. | |||
| /// \return Status The error code return | |||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||
| private: | |||
| bool epoch_ctrl_bypass_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ | |||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||
| @@ -1,58 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/opt/pre/removal_nodes.h" | |||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | |||
| #include "minddata/dataset/engine/datasetops/shuffle_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} | |||
| // Identifies the subtree below this node as a cached descendant tree. | |||
| Status RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; | |||
| is_caching_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Resets the tracking of the cache within the tree | |||
| Status RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; | |||
| is_caching_ = false; | |||
| return Status::OK(); | |||
| } | |||
| // Perform ShuffleOp removal check. | |||
| Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||
| *modified = false; | |||
| // If we are in a cache descendant tree, then this shuffle op needs to be removed | |||
| if (is_caching_) { | |||
| MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; | |||
| if (removal_pass_) { | |||
| removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node)); | |||
| } else { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,64 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class RemovalNodes removal_nodes.h | |||
| /// \brief This is a NodePass who's job is to identify which nodes should be removed. | |||
| /// It works in conjunction with the removal_pass. | |||
| class RemovalNodes : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] removal_pass Raw pointer back to controlling tree pass | |||
| explicit RemovalNodes(RemovalPass *removal_pass); | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Destructor | |||
| ~RemovalNodes() = default; | |||
| /// \brief Perform ShuffleOp removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||
| private: | |||
| bool is_caching_; | |||
| RemovalPass *removal_pass_; // Back pointer to the owning removal pass | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_ | |||
| @@ -16,32 +16,58 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/opt/pre/removal_nodes.h" | |||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | |||
| #include "minddata/dataset/engine/datasetops/shuffle_op.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {} | |||
| // Identifies the subtree below this node as a cached descendant tree. | |||
| // Sets is_caching_ so that the ShuffleOp handler can tell it is being run underneath a CacheOp. | |||
| // Nothing in the tree is changed here, so *modified is always set to false. | |||
| Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; | |||
| is_caching_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Resets the tracking of the cache within the tree | |||
| // Clears the is_caching_ flag that PreRunOnNode set for this CacheOp. | |||
| // Nothing in the tree is changed here, so *modified is always set to false. | |||
| Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; | |||
| is_caching_ = false; | |||
| return Status::OK(); | |||
| } | |||
| // Perform ShuffleOp removal check. | |||
| // The ShuffleOp is only appended to nodes_to_remove_ here; this function does not remove it from the tree, | |||
| // which is why *modified is always set to false. | |||
| Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||
| *modified = false; | |||
| // If we are in a cache descendant tree, then this shuffle op needs to be removed | |||
| if (is_caching_) { | |||
| MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; | |||
| nodes_to_remove_.push_back(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // constructor | |||
| RemovalPass::RemovalPass() {} | |||
| // Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||
| // Walk the tree to collect the nodes to remove, then removes them. | |||
| Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
| MS_LOG(INFO) << "Pre pass: removal pass started."; | |||
| // Create the removal node pass which can identify which nodes need to be removed. | |||
| std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this); | |||
| std::unique_ptr<RemovalPass::RemovalNodes> removal_nodes = std::make_unique<RemovalPass::RemovalNodes>(); | |||
| RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); | |||
| // Then, execute the removal of any nodes that were set up for removal | |||
| for (auto node : removal_nodes_) { | |||
| for (auto node : removal_nodes->nodes_to_remove()) { | |||
| node->Remove(); | |||
| } | |||
| MS_LOG(INFO) << "Pre pass: removal pass complete."; | |||
| return Status::OK(); | |||
| } | |||
| // Adds an operator to the list of operators to be removed | |||
| void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -30,6 +30,45 @@ class DatasetOp; | |||
| /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which | |||
| /// nodes should be removed, and then removes them. | |||
| class RemovalPass : public TreePass { | |||
| /// \class RemovalNodes | |||
| /// \brief This is a NodePass whose job is to identify which nodes should be removed. | |||
| /// It works in conjunction with the removal_pass. | |||
| class RemovalNodes : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| RemovalNodes(); | |||
| /// \brief Destructor | |||
| ~RemovalNodes() = default; | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Perform ShuffleOp removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||
| /// \brief Getter for the nodes collected for removal | |||
| /// \return A copy of the list of all the nodes to be removed | |||
| std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove() { return nodes_to_remove_; } | |||
| private: | |||
| // True while the nodes being visited are inside a cached descendant tree | |||
| bool is_caching_; | |||
| // The nodes that have been identified for removal | |||
| std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove_; | |||
| }; | |||
| public: | |||
| /// \brief Constructor | |||
| RemovalPass(); | |||
| @@ -42,13 +81,6 @@ class RemovalPass : public TreePass { | |||
| /// \param[inout] modified Indicator if the tree was modified. | |||
| /// \return Status The error code return | |||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||
| /// \brief Adds an operator to the list of operators to be removed | |||
| /// \param[in] dataset_op The operator to add to the removal list | |||
| void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op); | |||
| private: | |||
| std::vector<std::shared_ptr<DatasetOp>> removal_nodes_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards(): | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) | |||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value) | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id(): | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) | |||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value) | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard(): | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) | |||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value) | |||
| with pytest.raises(Exception) as error_info: | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) | |||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value) | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -245,7 +245,7 @@ def test_cv_minddataset_partition_num_samples_equals_0(): | |||
| num_iter += 1 | |||
| with pytest.raises(Exception) as error_info: | |||
| partitions(5) | |||
| assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info) | |||
| assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value) | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||